From c81b2dabbc45484dee2ca6658cfe39c841df5c70 Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Thu, 7 May 2026 13:28:18 +0900 Subject: [PATCH 001/289] ruby : transcribe without GVL, accept more MemoryViews, Windows support, fix memory size report, improve document (#3775) * Change MemoryView example using NDAV * Add note on audio attributes for #full and #full_parallel * Support more variants of MemoryView * Use IO.popen instead of Kernel.` for Windows compatibility * Use cmake's -C option instead of multiple -D options * Fix memsize calculation * Remove unused argument * Add is_interrupted field to abort callback container * Fix RBS syntax * Address document comment for RDoc * Add .document for RDoc * Add .rdoc_options * Run #full without GVL * Initialize callbacks with nil * Specify implicity Whisper::Params to distinguish from Whisper::Context::Params * Run callbacks without GVL * Call log callback with GVL * Run full_parallel without GVL * Run transcribe without GVL * Fix ruby_whisper_lock_gvl and ruby_whisper_unlock_gvl * Fix return value of encoder_begin_callback * Report GVL unlocking from transcribe * Remove unused interface * Restore overload of full_parallel * Close process * Fix struct name * Make is_without_gvl thread local * Use rb_thread_call_with_gvl instead of global variable * Retrieve instance variable in GVL * Narrow acceptable MemoryView format * Fix option cache path * Reduce files in package * Use append_cflags * Add ext/*.rb to task dependencies * Use copy instead of cp * Make TestPackage more portable * Patch for lower version Ruby * Make build scripts more portable * Add Windows support * Don't raise exceptions --- bindings/ruby/.document | 3 + bindings/ruby/.rdoc_options | 2 + bindings/ruby/README.md | 10 +- bindings/ruby/Rakefile | 4 +- bindings/ruby/ext/dependencies.rb | 14 +- bindings/ruby/ext/dependencies_for_windows.rb | 17 ++ bindings/ruby/ext/extconf.rb | 28 +- bindings/ruby/ext/options.rb | 68 ++++- bindings/ruby/ext/options_for_windows.rb | 51 ++++ bindings/ruby/ext/ruby_whisper.c | 36 ++- bindings/ruby/ext/ruby_whisper.h | 17 +- bindings/ruby/ext/ruby_whisper_context.c | 102 ++++++- bindings/ruby/ext/ruby_whisper_params.c | 261 +++++++++++++++--- bindings/ruby/ext/ruby_whisper_transcribe.cpp | 47 +++- bindings/ruby/extsources.rb | 36 ++- bindings/ruby/sig/whisper.rbs | 100 +++---- bindings/ruby/test/test_package.rb | 11 +- 17 files changed, 647 insertions(+), 160 deletions(-) create mode 100644 bindings/ruby/.document create mode 100644 bindings/ruby/.rdoc_options create mode 100644 bindings/ruby/ext/dependencies_for_windows.rb create mode 100644 bindings/ruby/ext/options_for_windows.rb diff --git a/bindings/ruby/.document b/bindings/ruby/.document new file mode 100644 index 00000000000..a8e9788fc7c --- /dev/null +++ b/bindings/ruby/.document @@ -0,0 +1,3 @@ +README.md +LICENSE +sig diff --git a/bindings/ruby/.rdoc_options b/bindings/ruby/.rdoc_options new file mode 100644 index 00000000000..cf14aa5f5b4 --- /dev/null +++ b/bindings/ruby/.rdoc_options @@ -0,0 +1,2 @@ +title: whispercpp +main_page: README.md diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 41e7b330d58..07b81830c58 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -360,7 +360,7 @@ Whisper::Context.new("base") ### Low-level API to transcribe ### -You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. +You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. Unlike `#transcribe`, these methods requires 16,000 Hz, 32-bit float audio. ```ruby require "whisper" @@ -383,16 +383,16 @@ If you can prepare audio data as C array and export it as a MemoryView, whisperc ```ruby require "torchaudio" -require "arrow-numo-narray" +require "ndav/torch/tensor" require "whisper" waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") -# Convert Torch::Tensor to Arrow::Array via Numo::NArray -samples = waveform.squeeze.numo.to_arrow.to_arrow_array +# Convert Torch::Tensor to NDAV +samples = waveform.squeeze.to_ndav whisper = Whisper::Context.new("base") whisper - # Arrow::Array exports MemoryView + # NDAV exports MemoryView .full(Whisper::Params.new, samples) ``` diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index d9a66030de4..7b521b3bdfa 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -16,7 +16,7 @@ EXTSOURCES.each do |src| file src directory dir file dest => [src, dir] do |t| - cp t.source, t.name + copy t.source, t.name end SOURCES.include dest end @@ -34,7 +34,7 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"]) SO_FILE = File.join("ext", LIB_NAME) LIB_FILE = File.join("lib", LIB_NAME) -file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t| +file "ext/Makefile" => SRC + SOURCES + FileList["ext/*.rb"] do |t| chdir "ext" do ruby "extconf.rb" end diff --git a/bindings/ruby/ext/dependencies.rb b/bindings/ruby/ext/dependencies.rb index 2ba4b94b62b..b2eb9beb84f 100644 --- a/bindings/ruby/ext/dependencies.rb +++ b/bindings/ruby/ext/dependencies.rb @@ -22,13 +22,17 @@ def libs else nil end - }.reverse.collect {|lib| "lib#{lib}.a"} + }.reverse.collect {|lib| "#{prefix(lib)}#{lib}.#{RbConfig::CONFIG['LIBEXT']}"} end def to_s libs.join(" ") end + def local_libs + to_s + end + private def dot_path @@ -36,9 +40,7 @@ def dot_path end def generate_dot - args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF"] - args << @options.to_s unless @options.to_s.empty? - system @cmake, *args, exception: true + system @cmake, "-S", "sources", "-B", "build", *@options.graphviz_cmake_args, "--graphviz", dot_path, *@options, exception: true end def parse_dot @@ -59,6 +61,10 @@ def parse_dot end end + def prefix(lib) + "lib" + end + def tsort_each_node @nodes.each_key do |node| yield node diff --git a/bindings/ruby/ext/dependencies_for_windows.rb b/bindings/ruby/ext/dependencies_for_windows.rb new file mode 100644 index 00000000000..5574107182d --- /dev/null +++ b/bindings/ruby/ext/dependencies_for_windows.rb @@ -0,0 +1,17 @@ +require_relative "dependencies" + +class DependenciesForWindows < Dependencies + def local_libs + libs.collect {|lib| %|"#{lib_path(lib)}"|}.join(" ") + end + + private + + def prefix(lib) + lib.start_with?("ggml") ? "" : "lib" + end + + def lib_path(lib) + File.join(__dir__, lib).tr("\\", "/") + end +end diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index acff501aa3b..4b09b6ebe13 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -1,15 +1,27 @@ require "mkmf" -require_relative "options" -require_relative "dependencies" + +if RUBY_PLATFORM.match? /mswin|mingw|ucrt/ + require_relative "options_for_windows" + require_relative "dependencies_for_windows" + + Opts = OptionsForWindows + Deps = DependenciesForWindows +else + require_relative "options" + require_relative "dependencies" + + Opts = Options + Deps = Dependencies +end cmake = find_executable("cmake") || abort -options = Options.new(cmake).to_s +options = Opts.new(cmake) have_library("gomp") rescue nil -libs = Dependencies.new(cmake, options).to_s +libs = Deps.new(cmake, options) -$CFLAGS << " -O3 -march=native" +append_cflags ["-O3", "-march=native"] $INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" -$LOCAL_LIBS << " #{libs}" +$LOCAL_LIBS << " #{libs.local_libs}" $cleanfiles << " build #{libs}" create_makefile "whisper" do |conf| @@ -17,7 +29,7 @@ $(TARGET_SO): #{libs} #{libs}: cmake-targets cmake-targets: - #{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options} - #{"\t"}#{cmake} --build build --config Release --target common whisper + #{"\t"}"#{cmake}" -S sources -B build #{options} + #{"\t"}"#{cmake}" --build build --config Release --target common whisper EOF end diff --git a/bindings/ruby/ext/options.rb b/bindings/ruby/ext/options.rb index ede80c0656b..e723af9fd9a 100644 --- a/bindings/ruby/ext/options.rb +++ b/bindings/ruby/ext/options.rb @@ -1,26 +1,36 @@ +require "fileutils" + class Options def initialize(cmake="cmake") @cmake = cmake @options = {} configure + write_cache_file + end + + def to_a + [ + "-D", "BUILD_SHARED_LIBS=OFF", + "-D", "WHISPER_BUILD_TESTS=OFF", + "-D", "CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__}", + "-D", "CMAKE_POSITION_INDEPENDENT_CODE=ON", + "-C", cache_path + ] end def to_s - @options - .reject {|name, (type, value)| value.nil?} - .collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? "OFF" : value.shellescape}"} - .join(" ") + command_line(*to_a) end - def cmake_options - return @cmake_options if @cmake_options + def graphviz_cmake_args + [] + end - output = nil - Dir.chdir __dir__ do - output = `#{@cmake.shellescape} -S sources -B build -L` - end - @cmake_options = output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1) + private + + def cmake_options + @cmake_options ||= cmake_options_output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1) .filter_map {|line| option, value = line.chomp.split("=", 2) name, type = option.split(":", 2) @@ -34,7 +44,11 @@ def cmake_options }.to_h end - private + def cmake_options_output + Dir.chdir(__dir__) do + IO.popen([@cmake, "-S", "sources", "-B", "build", "-L"]) {|io| io.read} + end + end def configure cmake_options.each_pair do |name, (type, default_value)| @@ -74,12 +88,38 @@ def option_name(name) def enabled?(option) op = @options[option] - raise "Option not exist: #{option}" unless op - raise "Option not boolean: #{option}(#{op[0]})" unless op[0] == "BOOL" + return false unless op + return false unless op[0] == "BOOL" if op[1].nil? cmake_options[option][1] else op[1] end end + + def cache_path + File.join(__dir__, "sources", "Options.cmake") + end + + def write_cache_file + FileUtils.mkpath File.dirname(cache_path) + File.open cache_path, "w" do |file| + @options.reject {|name, (type, value)| value.nil?}.each do |name, (type, value)| + line = "set(CACHE{%s} TYPE %s FORCE VALUE %s)" % { + name:, + type:, + value: value == true ? "ON" : value == false ? "OFF" : escape_cmake(value) + } + file.puts line + end + end + end + + def escape_cmake(str) + str.gsub(/[\\"]/, '\\\\\&') + end + + def command_line(*args) + args.collect {|arg| %|"#{arg.to_s.gsub(/[\\"]/, '\\\\\&')}"|}.join(" ") + end end diff --git a/bindings/ruby/ext/options_for_windows.rb b/bindings/ruby/ext/options_for_windows.rb new file mode 100644 index 00000000000..7db785d8a2d --- /dev/null +++ b/bindings/ruby/ext/options_for_windows.rb @@ -0,0 +1,51 @@ +require_relative "options" + +class OptionsForWindows < Options + def to_s + command_line(*generator_args, *to_a) + end + + def graphviz_cmake_args + generator_args + end + + private + + def arm? + RbConfig::CONFIG["host_cpu"].to_s.downcase.match?(/\A(?:arm64|aarch64)\z/) + end + + def cmake_options_output + Dir.chdir(__dir__) do + IO.popen([@cmake, "-S", "sources", "-B", "build", *generator_args, "-L"]) {|io| io.read} + end + end + + def generator_args + generator = cmake_generator + ["-G", generator] if generator && !generator.empty? + end + + def cmake_generator + return @cmake_generator if defined?(@cmake_generator) + + generator = ENV["CMAKE_GENERATOR"] + abort "CMAKE_GENERATOR=#{generator} is unsupported for mingw/ucrt Ruby" if visual_studio_generator_name?(generator) + return @cmake_generator = generator unless generator.nil? || generator.empty? + + ninja = find_executable("ninja") + return @cmake_generator = "Ninja" if ninja + + make = find_executable("make") + return @cmake_generator = "MSYS Makefiles" if make + + mingw32_make = find_executable("mingw32-make") + return @cmake_generator = "MinGW Makefiles" if mingw32_make + + @cmake_generator = nil + end + + def visual_studio_generator_name?(generator) + generator && generator.start_with?("Visual Studio") + end +end diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 5f1917ee805..56fceb1c894 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -29,6 +29,7 @@ ID id_cache; ID id_n_processors; static bool is_log_callback_finalized = false; +static bool is_ruby_log_callback_present = false; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); @@ -106,18 +107,43 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { return Qnil; } +typedef struct { + int level; + const char * buffer; +} call_log_callbacks_args; + +static void* +call_log_callbacks(void *v_args) { + VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (NIL_P(log_callback)) { + return NULL; + } + + call_log_callbacks_args *args = (call_log_callbacks_args *)v_args; + VALUE user_data = rb_iv_get(mWhisper, "user_data"); + rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data); + + return NULL; +} + static void ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { if (is_log_callback_finalized) { return; } - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (NIL_P(log_callback)) { + if (!is_ruby_log_callback_present) { return; } - VALUE udata = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); + call_log_callbacks_args args = { + level, + buffer, + }; + if (ruby_thread_has_gvl_p()) { + call_log_callbacks((void *)&args); + } else { + rb_thread_call_with_gvl(call_log_callbacks, (void *)&args); + } } /* @@ -140,8 +166,10 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d if (NIL_P(log_callback)) { whisper_log_set(NULL, NULL); + is_ruby_log_callback_present = false; } else { whisper_log_set(ruby_whisper_log_callback, NULL); + is_ruby_log_callback_present = true; } return Qnil; diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 6b0b4df7214..ba4d8b6fbcc 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -2,10 +2,17 @@ #define RUBY_WHISPER_H #include +#include #include +#include #include #include "whisper.h" +#if RUBY_API_VERSION_MAJOR < 4 +// Exists but not declared as public API +int ruby_thread_has_gvl_p(void); +#endif + typedef struct { VALUE *context; VALUE user_data; @@ -13,6 +20,14 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; +typedef struct { + VALUE *context; + VALUE user_data; + VALUE callback; + VALUE callbacks; + bool is_interrupted; +} ruby_whisper_abort_callback_container; + typedef struct { struct whisper_context *context; } ruby_whisper; @@ -27,7 +42,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_callback_container *abort_callback_container; + ruby_whisper_abort_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 6e38ead6321..26058fc07e6 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -1,5 +1,11 @@ #include "ruby_whisper.h" +#ifdef WORDS_BIGENDIAN + #define IS_BIGENDIAN true +#else + #define IS_BIGENDIAN false +#endif + extern ID id_to_s; extern ID id___method__; extern ID id_to_enum; @@ -47,6 +53,27 @@ typedef struct full_parallel_args { int n_processors; } full_parallel_args; +typedef struct full_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int result; +} full_without_gvl_args; + +typedef struct full_parallel_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int n_processors; + int result; +} full_parallel_without_gvl_args; + +typedef struct full_ubf_args { + ruby_whisper_abort_callback_container *abort_callback_container; +} full_ubf_args; + static void ruby_whisper_free(ruby_whisper *rw) { @@ -74,7 +101,7 @@ static size_t ruby_whisper_memsize(const void *p) { const ruby_whisper *rw = (const ruby_whisper *)p; - size_t size = sizeof(rw); + size_t size = sizeof(*rw); if (!rw) { return 0; } @@ -304,11 +331,25 @@ VALUE ruby_whisper_model_type(VALUE self) static bool check_memory_view(rb_memory_view_t *memview) { - if (memview->format != NULL && strcmp(memview->format, "f") != 0) { - rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format); + if (!memview->format) { + rb_warn("currently format is required"); + return false; + } + + if (strcmp(memview->format, "f") == 0) { + // accept + } else if (strcmp(memview->format, "e") == 0) { + if (IS_BIGENDIAN) { + rb_warn("currently format \"e\" is only supported on little-endian environment"); + return false; + } + } else { + rb_warn("currently only format \"f\" and \"e\" on little-endian environment is supported for MemoryView, but given: %s", memview->format); return false; } - if (memview->format != NULL && memview->ndim != 1) { + + if (memview->ndim != 1 && !(memview->ndim == 2 && memview->shape[1] == 1)) { + // TODO: Accept ndim == 2 with shape [n_samples, channels] and channels > 1 by averaging the samples in different channels or just taking the first channel rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); return false; } @@ -426,6 +467,22 @@ release_samples(VALUE rb_parsed_args) return Qnil; } +static void* +full_without_gvl(void *rb_args) +{ + full_without_gvl_args *args = (full_without_gvl_args *)rb_args; + args->result = whisper_full(args->context, *args->params, args->samples, args->n_samples); + return NULL; +} + +static void +full_ubf(void *rb_args) +{ + full_ubf_args *args = (full_ubf_args *)rb_args; + + args->abort_callback_container->is_interrupted = true; +} + static VALUE full_body(VALUE rb_args) { @@ -437,9 +494,19 @@ full_body(VALUE rb_args) TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); prepare_transcription(rwp, args->context, 1); - int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); - return INT2NUM(result); + struct full_without_gvl_args full_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + 0, + }; + full_ubf_args full_ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_without_gvl_args.result); } /* @@ -477,6 +544,14 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static void* +full_parallel_without_gvl(void *rb_args) +{ + full_parallel_without_gvl_args *args = (full_parallel_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + return NULL; +} + static VALUE full_parallel_body(VALUE rb_args) { @@ -488,9 +563,20 @@ full_parallel_body(VALUE rb_args) TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); prepare_transcription(rwp, args->context, args->n_processors); - int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); - return INT2NUM(result); + struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + args->n_processors, + 0, + }; + full_ubf_args full_ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_parallel_without_gvl_args.result); } /* diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 3e5dca9c1e1..2aae7c12d19 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -93,21 +93,66 @@ rb_whisper_callback_container_allocate() { container->context = NULL; container->user_data = Qnil; container->callback = Qnil; - container->callbacks = rb_ary_new(); + container->callbacks = Qnil; return container; } -static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; +static void +rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) +{ + if (rwc == NULL) return; + + rb_gc_mark(rwc->user_data); + rb_gc_mark(rwc->callback); + rb_gc_mark(rwc->callbacks); +} + +static ruby_whisper_abort_callback_container* +rb_whisper_abort_callback_container_allocate() { + ruby_whisper_abort_callback_container *container; + container = ALLOC(ruby_whisper_abort_callback_container); + container->context = NULL; + container->user_data = Qnil; + container->callback = Qnil; + container->callbacks = Qnil; + container->is_interrupted = false; + return container; +} + +static bool +ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { + return !NIL_P(container->callback) || !NIL_P(container->callbacks); +} + +static bool +ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { + return !NIL_P(container->callback) || !NIL_P(container->callbacks); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int n_new; +} call_new_segment_callbacks_args; + +static void* +call_new_segment_callbacks(void *v_args) { + call_new_segment_callbacks_args *args = (call_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + struct whisper_state *state = args->state; + int n_new = args->n_new; // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); } + if (NIL_P(container->callbacks)) { + return NULL; + } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return; + return NULL; } const int n_segments = whisper_full_n_segments_from_state(state); for (int i = n_new; i > 0; i--) { @@ -118,95 +163,208 @@ static void new_segment_callback(struct whisper_context *ctx, struct whisper_sta rb_funcall(cb, id_call, 1, segment); } } + + return NULL; } -static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { +static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - const VALUE progress = INT2NUM(progress_cur); - // Currently, doesn't support state because + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_new_segment_callbacks_args args = { + container, + state, + n_new + }; + rb_thread_call_with_gvl(call_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int progress_cur; +} call_progress_callbacks_args; + +static void* +call_progress_callbacks(void *v_args) { + call_progress_callbacks_args *args = (call_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + int progress_cur = args->progress_cur; + + // Currently, doesn't support state because // those require to resolve GC-related problems. - if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); + if (!NIL_P(args->container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(progress_cur), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return; + return NULL; } for (int j = 0; j < callbacks_len; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, progress); + rb_funcall(cb, id_call, 1, INT2NUM(progress_cur)); } + + return NULL; } -static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { +static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - bool is_aborted = false; - VALUE result; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_progress_callbacks_args args = { + container, + state, + progress_cur + }; + rb_thread_call_with_gvl(call_progress_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + bool is_continued; +} call_encoder_begin_callbacks_args; + +static void* +call_encoder_begin_callbacks(void *v_args) { + call_encoder_begin_callbacks_args *args = (call_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); if (result == Qfalse) { - is_aborted = true; + args->is_continued = false; + return NULL; } } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return !is_aborted; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - result = rb_funcall(cb, id_call, 0); - if (result == Qfalse) { - is_aborted = true; + if (!NIL_P(container->callbacks)) { + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return NULL; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } } } - return !is_aborted; + + return NULL; } -static bool abort_callback(void * user_data) { +static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_encoder_begin_callbacks_args args = { + container, + state, + true + }; + rb_thread_call_with_gvl(call_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_abort_callback_container *container; + struct whisper_state *state; + bool is_interrupted; +} call_abort_callbacks_args; + +static void* +call_abort_callbacks(void *v_args) { + call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; + const ruby_whisper_abort_callback_container *container = args->container; + + if (container->is_interrupted) { + args->is_interrupted = true; + return NULL; + } + if (!NIL_P(container->callback)) { VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); if (!NIL_P(result) && Qfalse != result) { - return true; + args->is_interrupted = true; + return NULL; } } + if (NIL_P(container->callbacks)) { + return NULL; + } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return false; + return NULL; } for (int j = 0; j < callbacks_len; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); VALUE result = rb_funcall(cb, id_call, 1, container->user_data); if (!NIL_P(result) && Qfalse != result) { - return true; + args->is_interrupted = true; + return NULL; } } - return false; + + return NULL; +} + +static bool abort_callback(void * user_data) { + const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + + if (container->is_interrupted) { + return true; + } + + if (!ruby_whisper_abort_callback_container_is_present(container)) { + return false; + } + + call_abort_callbacks_args args = { + container, + NULL, + false + }; + rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); + + return args.is_interrupted; } static void -check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) +check_thread_safety(ruby_whisper_params *rwp, int n_processors) { if (n_processors == 1) { return; } - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); } - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); } - if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } @@ -217,29 +375,28 @@ check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) } static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; } - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { rwp->progress_callback_container->context = context; rwp->params.progress_callback = progress_callback; rwp->params.progress_callback_user_data = rwp->progress_callback_container; } - if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rwp->encoder_begin_callback_container->context = context; rwp->params.encoder_begin_callback = encoder_begin_callback; rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { - rwp->abort_callback_container->context = context; - rwp->params.abort_callback = abort_callback; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; - } + rwp->abort_callback_container->context = context; + rwp->params.abort_callback = abort_callback; + rwp->abort_callback_container->is_interrupted = false; + rwp->params.abort_callback_user_data = rwp->abort_callback_container; } static void set_vad_params(ruby_whisper_params *rwp) @@ -255,7 +412,7 @@ static void set_vad_params(ruby_whisper_params *rwp) void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) { - check_thread_safety(rwp, context, n_processors); + check_thread_safety(rwp, n_processors); register_callbacks(rwp, context); set_vad_params(rwp); } @@ -267,7 +424,7 @@ rb_whisper_params_mark(void *p) rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container); rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_callbcack_container_mark(rwp->abort_callback_container); + rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -338,7 +495,7 @@ ruby_whisper_params_allocate(VALUE klass) rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_callback_container_allocate(); + rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); return obj; } @@ -1302,6 +1459,9 @@ ruby_whisper_params_on_new_segment(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->new_segment_callback_container->callbacks)) { + rwp->new_segment_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->new_segment_callback_container->callbacks, blk); return Qnil; } @@ -1322,6 +1482,9 @@ ruby_whisper_params_on_progress(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->progress_callback_container->callbacks)) { + rwp->progress_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->progress_callback_container->callbacks, blk); return Qnil; } @@ -1342,6 +1505,9 @@ ruby_whisper_params_on_encoder_begin(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->encoder_begin_callback_container->callbacks)) { + rwp->encoder_begin_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk); return Qnil; } @@ -1366,6 +1532,9 @@ ruby_whisper_params_abort_on(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->abort_callback_container->callbacks)) { + rwp->abort_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->abort_callback_container->callbacks, blk); return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 3d00566009a..37656af1c44 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -15,8 +15,37 @@ extern ID id_call; extern ID id_to_path; extern ID transcribe_option_names[1]; -extern void -prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); + +typedef struct{ + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + size_t n_samples; + int n_processors; + int result; +} transcribe_without_gvl_args; + +static void* +transcribe_without_gvl(void *rb_args) +{ + transcribe_without_gvl_args *args = (transcribe_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_container *abort_callback_container; +} transcribe_ubf_args; + +static void +transcribe_ubf(void *rb_args) +{ + transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; + + args->abort_callback_container->is_interrupted = true; +} /* * transcribe a single file @@ -75,7 +104,19 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { prepare_transcription(rwp, &self, n_processors); - if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) { + transcribe_without_gvl_args args = { + rw->context, + &rwp->params, + pcmf32.data(), + pcmf32.size(), + n_processors, + 0, + }; + transcribe_ubf_args ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); + if (args.result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/extsources.rb b/bindings/ruby/extsources.rb index b24f1a7f13d..850ac9841b1 100644 --- a/bindings/ruby/extsources.rb +++ b/bindings/ruby/extsources.rb @@ -5,37 +5,53 @@ .devops .github ci - examples/wchess/wchess.wasm + examples/addon.node + examples/bench.wasm + examples/command + examples/command.wasm + examples/lsp + examples/main + examples/python + examples/stream + examples/stream.wasm + examples/sycl + examples/talk-llama + examples/wchess examples/whisper.android examples/whisper.android.java + examples/whisper.nvim examples/whisper.objc examples/whisper.swiftui + examples/whisper.wasm grammars models samples scripts + tests ].collect {|dir| root/dir} ignored_files = %w[ AUTHORS Makefile - README.md - README_sycl.md .gitignore .gitmodules .dockerignore - whisper.nvim - twitch.sh - yt-wsp.sh - close-issue.yml - build-xcframework.sh +] +ignored_exts = %w[ + .yml + .sh + .md + .py + .js + .nvim ] EXTSOURCES = `git ls-files -z #{root}`.split("\x0") .collect {|file| Pathname(file)} .reject {|file| - ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} || + ignored_exts.include?(file.extname) || ignored_files.include?(file.basename.to_path) || - (file.descend.to_a[1] != root && file.descend.to_a[1] != Pathname("..")/"javascript") + ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} || + (file.descend.to_a[1] != root && file != Pathname("..")/"javascript"/"package-tmpl.json") } .collect(&:to_path) diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 3c59661975b..cbec4803820 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -5,10 +5,10 @@ module Whisper end type log_callback = ^(Integer level, String message, Object user_data) -> void - type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void - type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void - type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void - type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish + type new_segment_callback = ^(Whisper::Context, untyped, Integer n_new, Object user_data) -> void + type progress_callback = ^(Whisper::Context, untyped, Integer progress, Object user_data) -> void + type encoder_begin_callback = ^(Whisper::Context, untyped, Object user_data) -> void + type abort_callback = ^(Whisper::Context, untyped, Object user_data) -> boolish VERSION: String LOG_LEVEL_NONE: Integer @@ -52,11 +52,11 @@ module Whisper # puts text # end # - # If n_processors is greater than 1, you cannot set any callbacks including + # If `n_processors` is greater than 1, you cannot set any callbacks including # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, # and log_callback set by Whisper.log_set - def transcribe: (path, Params, ?n_processors: Integer) -> self - | (path, Params, ?n_processors: Integer) { (String) -> void } -> self + def transcribe: (path, Whisper::Params, ?n_processors: Integer) -> self + | (path, Whisper::Params, ?n_processors: Integer) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer @@ -74,7 +74,7 @@ module Whisper # puts segment.text # end # - # Returns an Enumerator if no block given: + # Returns an `Enumerator` if no block given: # # whisper.transcribe("path/to/audio.wav", params) # enum = whisper.each_segment @@ -91,25 +91,25 @@ module Whisper # def full_lang_id: () -> Integer - # Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). # # full_get_segment_t0(3) # => 1668 (16680 ms) # def full_get_segment_t0: (Integer) -> Integer - # End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). # # full_get_segment_t1(3) # => 1668 (16680 ms) # def full_get_segment_t1: (Integer) -> Integer - # Whether the next segment indexed by +segment_index+ is predicated as a speaker turn. + # Whether the next segment indexed by `segment_index` is predicated as a speaker turn. # # full_get_segment_speacker_turn_next(3) # => true # def full_get_segment_speaker_turn_next: (Integer) -> (true | false) - # Text of a segment indexed by +segment_index+. + # Text of a segment indexed by `segment_index`. # # full_get_segment_text(3) # => "ask not what your country can do for you, ..." # @@ -117,27 +117,27 @@ module Whisper def full_get_segment_no_speech_prob: (Integer) -> Float - # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - # Not thread safe for same context + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + # Not thread safe for same context # Uses the specified decoding strategy to obtain the text. # - # The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # The second argument `samples` must be an array of samples, respond to `:length`, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. # - def full: (Params, Array[Float] samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self + def full: (Whisper::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self - # Split the input audio in chunks and process each chunk separately using whisper_full_with_state() - # Result is stored in the default state of the context - # Not thread safe if executed in parallel on the same context. - # It seems this approach can offer some speedup in some cases. + # Split the input audio in chunks and process each chunk separately using `whisper_full_with_state()` + # Result is stored in the default state of the context + # Not thread safe if executed in parallel on the same context. + # It seems this approach can offer some speedup in some cases. # However, the transcription accuracy can be worse at the beginning and end of each chunk. # - # If n_processors is greater than 1, you cannot set any callbacks including + # If `n_processors` is greater than 1, you cannot set any callbacks including # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, # and log_callback set by Whisper.log_set - def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self + def full_parallel: (Whisper::Params, Array[Float], ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self def to_srt: () -> String def to_webvtt: () -> String @@ -217,35 +217,35 @@ module Whisper def translate: () -> (true | false) def no_context=: (boolish) -> boolish - # If true, does not use past transcription (if any) as initial prompt for the decoder. + # If `true`, does not use past transcription (if any) as initial prompt for the decoder. # def no_context: () -> (true | false) def single_segment=: (boolish) -> boolish - # If true, forces single segment output (useful for streaming). + # If `true`, forces single segment output (useful for streaming). # def single_segment: () -> (true | false) def print_special=: (boolish) -> boolish - # If true, prints special tokens (e.g. , , , etc.). + # If `true`, prints special tokens (e.g. , , , etc.). # def print_special: () -> (true | false) def print_progress=: (boolish) -> boolish - # If true, prints progress information. + # If `true`, prints progress information. # def print_progress: () -> (true | false) def print_realtime=: (boolish) -> boolish - # If true, prints results from within whisper.cpp. (avoid it, use callback instead) + # If `true`, prints results from within whisper.cpp. (avoid it, use callback instead) # def print_realtime: () -> (true | false) - # If true, prints timestamps for each text segment when printing realtime. + # If `true`, prints timestamps for each text segment when printing realtime. # def print_timestamps=: (boolish) -> boolish @@ -253,19 +253,19 @@ module Whisper def suppress_blank=: (boolish) -> boolish - # If true, suppresses blank outputs. + # If `true`, suppresses blank outputs. # def suppress_blank: () -> (true | false) def suppress_nst=: (boolish) -> boolish - # If true, suppresses non-speech-tokens. + # If `true`, suppresses non-speech-tokens. # def suppress_nst: () -> (true | false) def token_timestamps=: (boolish) -> boolish - # If true, enables token-level timestamps. + # If `true`, enables token-level timestamps. # def token_timestamps: () -> (true | false) @@ -277,16 +277,16 @@ module Whisper def split_on_word=: (boolish) -> boolish - # If true, split on word rather than on token (when used with max_len). + # If `true`, split on word rather than on token (when used with max_len). # def split_on_word: () -> (true | false) def initial_prompt=: (_ToS) -> _ToS def carry_initial_prompt=: (boolish) -> boolish - # Tokens to provide to the whisper decoder as initial prompt - # these are prepended to any existing text context from a previous call - # use whisper_tokenize() to convert text to tokens. + # Tokens to provide to the whisper decoder as initial prompt + # these are prepended to any existing text context from a previous call + # use whisper_tokenize() to convert text to tokens. # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). # def initial_prompt: () -> (String | nil) @@ -294,7 +294,7 @@ module Whisper def diarize=: (boolish) -> boolish - # If true, enables diarization. + # If `true`, enables diarization. # def diarize: () -> (true | false) @@ -423,7 +423,7 @@ module Whisper # def on_new_segment: { (Segment) -> void } -> void - # Hook called on progress update. Yields each progress Integer between 0 and 100. + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. # def on_progress: { (Integer progress) -> void } -> void @@ -431,7 +431,7 @@ module Whisper # def on_encoder_begin: { () -> void } -> void - # Call block to determine whether abort or not. Return +true+ when you want to abort. + # Call block to determine whether abort or not. Return `true` when you want to abort. # # params.abort_on do # if some_condition @@ -504,13 +504,13 @@ module Whisper # Yields each Whisper::Token: # - # whisper.each_segment.first.each_token do |token| - # p token - # end + # whisper.each_segment.first.each_token do |token| + # p token + # end # - # Returns an Enumerator if no block is given: + # Returns an `Enumerator` if no block is given: # - # whisper.each_segment.first.each_token.to_a # => [#, ...] + # whisper.each_segment.first.each_token.to_a # => [#, ...] # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] @@ -518,7 +518,7 @@ module Whisper def to_webvtt_cue: () -> String - # Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next + # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # # whisper.each_segment do |segment| # segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:} @@ -569,7 +569,7 @@ module Whisper # [EXPERIMENTAL] Token-level timestamps with DTW # - # Do not use if you haven't computed token-level timestamps with dtw. + # Do not use if you haven't computed token-level timestamps with dtw. # Roughly corresponds to the moment in audio in which the token was output. # def t_dtw: () -> Integer @@ -580,14 +580,14 @@ module Whisper # Start time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def start_time: () -> Integer # End time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def end_time: () -> Integer diff --git a/bindings/ruby/test/test_package.rb b/bindings/ruby/test/test_package.rb index 108f34efbeb..f99012cce83 100644 --- a/bindings/ruby/test/test_package.rb +++ b/bindings/ruby/test/test_package.rb @@ -1,12 +1,12 @@ require_relative "helper" require 'tempfile' require 'tmpdir' -require 'shellwords' +require 'open3' class TestPackage < TestBase def test_build Tempfile.create do |file| - assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) + assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path, exception: true) assert file.size > 0 assert_path_exist file.to_path end @@ -20,7 +20,7 @@ def setup def test_install gemspec = Gem::Specification.load("whispercpp.gemspec") Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true + system "gem", "install", "--install-dir", dir, "--no-document", File.join("pkg", gemspec.file_name), exception: true assert_installed dir, gemspec.version end end @@ -29,13 +29,14 @@ def test_install_with_coreml omit_unless RUBY_PLATFORM.match?(/darwin/) do gemspec = Gem::Specification.load("whispercpp.gemspec") Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true + system "gem", "install", "--install-dir", dir, "--no-document", File.join("pkg", gemspec.file_name), "--", "--enable-whisper-coreml", exception: true assert_installed dir, gemspec.version libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib") assert_nothing_raised do system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true end - assert_match(/COREML = 1/, `ruby -I #{libdir.shellescape} -r whisper -e 'puts Whisper.system_info_str'`) + output, status = Open3.capture2("ruby", "-I", libdir, "-r", "whisper", "-e", "puts Whisper.system_info_str") + assert_match /COREML = 1/, output end end end From c33c5618b72bb345df029b730b36bc0e369845a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bjarke=20Viks=C3=B8e?= <164612031+bviksoe@users.noreply.github.com> Date: Sun, 10 May 2026 16:24:12 +0200 Subject: [PATCH 002/289] whisper : fix incorrect timestamps, usually near silences (#2279) * Incorrect timetstamps Fixes #2271 - Adds consecutive timestamps after end of last segment as the new starting ts - Add these timestamp to output when "print-special" enabled - Fixes fflush usage in live reporting I was not able to test this with the special "token_timestamps" option. * Skip initial timestamp --- src/whisper.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 2f356da0f06..6176d21f53c 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -7659,11 +7659,14 @@ int whisper_full_with_state( } } text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + t0 = t1; + while (i + 1 < (int) tokens_cur.size() && tokens_cur[i + 1].id > whisper_token_beg(ctx)) { i++; + if (params.print_special) { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + t0 = seek + 2 * (tokens_cur[i].tid - whisper_token_beg(ctx)); } - i--; - t0 = t1; i0 = i + 1; speaker_turn_next = false; } @@ -7680,8 +7683,8 @@ int whisper_full_with_state( printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); } else { printf("%s", text.c_str()); - fflush(stdout); } + fflush(stdout); } result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next }); From 338cce1e58133261753243802a0e7a430118866d Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Tue, 12 May 2026 07:36:00 +0200 Subject: [PATCH 003/289] server: Add support for controlling token_timestamps directly (#3785) --- examples/server/server.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f6a7a83181a..08c0988d2be 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -101,6 +101,7 @@ struct whisper_params { bool print_realtime = false; bool print_progress = false; bool no_timestamps = false; + bool token_timestamps = true; bool use_gpu = true; bool flash_attn = true; int32_t gpu_device = 0; @@ -550,6 +551,12 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content); } + if (req.has_file("token_timestamps")) + { + params.token_timestamps = parse_str_to_bool(req.get_file_value("token_timestamps").content); + } else { + params.token_timestamps = !params.no_timestamps; + } if (req.has_file("language")) { params.language = req.get_file_value("language").content; @@ -690,10 +697,10 @@ int main(int argc, char ** argv) { if (params.dtw == "large.v3") { cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; } - if (params.dtw == "large.v3.turbo") { + if (params.dtw == "large.v3.turbo") { cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3_TURBO; } - + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); return 3; @@ -939,7 +946,7 @@ int main(int argc, char ** argv) { wparams.logprob_thold = params.logprob_thold; wparams.no_timestamps = params.no_timestamps; - wparams.token_timestamps = !params.no_timestamps; + wparams.token_timestamps = params.token_timestamps; wparams.no_context = params.no_context; wparams.suppress_nst = params.suppress_nst; @@ -1043,7 +1050,7 @@ int main(int argc, char ** argv) { res.set_content(ss.str(), "text/vtt"); } else if (params.response_format == vjson_format) { /* try to match openai/whisper's Python format */ - std::string results = output_str(ctx, params, pcmf32s); + std::string results = output_str(ctx, params, pcmf32s); json jres = json{ {"task", params.translate ? "translate" : "transcribe"}, {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, @@ -1088,7 +1095,7 @@ int main(int argc, char ** argv) { segment["tokens"].push_back(token.id); json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; - if (!params.no_timestamps) { + if (!params.no_timestamps && params.token_timestamps) { word["start"] = token.t0 * 0.01; word["end"] = token.t1 * 0.01; word["t_dtw"] = token.t_dtw; From f08258abd74b995bb95d8005103f72f1afd66a8a Mon Sep 17 00:00:00 2001 From: annaeina <2846698728@qq.com> Date: Wed, 13 May 2026 13:32:00 +0800 Subject: [PATCH 004/289] whisper : fix max_tokens skipping remaining audio (#3798) * whisper: fix max_tokens skipping remaining audio * add PR reference comment as suggested Co-authored-by: Georgi Gerganov * fix(ci): enable artifact overwrite --- .github/workflows/build.yml | 1 + bindings/go/pkg/whisper/context_test.go | 48 +++++++++++++++++++++++++ src/whisper.cpp | 12 +++++++ 3 files changed, 61 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fb115b22abb..be3f78a3f5b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -662,6 +662,7 @@ jobs: with: name: ggml_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml.dll + overwrite: true - name: Upload ggml base dll uses: actions/upload-artifact@v6 diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go index e98a4c2b80b..79f6a593024 100644 --- a/bindings/go/pkg/whisper/context_test.go +++ b/bindings/go/pkg/whisper/context_test.go @@ -2,6 +2,7 @@ package whisper_test import ( "os" + "strings" "testing" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" @@ -92,6 +93,53 @@ func TestProcess(t *testing.T) { assert.NoError(err) } +func TestProcessMaxTokensPerSegment(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + + fh, err := os.Open(SamplePath) + assert.NoError(err) + defer fh.Close() + + // Decode the WAV file - load the full buffer + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + assert.NoError(err) + assert.Equal(uint16(1), dec.NumChans) + + data := buf.AsFloat32Buffer().Data + + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + defer model.Close() + + context, err := model.NewContext() + assert.NoError(err) + + context.SetMaxTokensPerSegment(5) + + err = context.Process(data, nil, nil, nil) + assert.NoError(err) + + var text strings.Builder + nSegments := 0 + for { + segment, err := context.NextSegment() + if err != nil { + break + } + nSegments++ + text.WriteString(segment.Text) + } + + assert.Greater(nSegments, 1) + assert.Contains(text.String(), "country") +} + func TestDetectedLanguage(t *testing.T) { assert := assert.New(t) diff --git a/src/whisper.cpp b/src/whisper.cpp index 6176d21f53c..210ca597fb4 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -6216,6 +6216,13 @@ static void whisper_process_logits( } } + // ref: https://github.com/ggml-org/whisper.cpp/pull/3798 + if (!params.no_timestamps && !params.single_segment && params.max_tokens > 0 && (int) tokens_cur.size() >= params.max_tokens) { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; logits[vocab.token_nosp] = -INFINITY; @@ -7725,7 +7732,12 @@ int whisper_full_with_state( } // ref: https://github.com/ggml-org/whisper.cpp/pull/2629 + const bool max_tokens_timestamp_ending = params.max_tokens > 0 && + !params.single_segment && + tokens_cur.size() > (size_t) params.max_tokens; + const bool single_timestamp_ending = tokens_cur.size() > 1 && + !max_tokens_timestamp_ending && tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) && tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx); if (single_timestamp_ending) { From a604a9b5b0ff9108191769a09843ae325c6c0d7f Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Wed, 13 May 2026 08:54:56 +0200 Subject: [PATCH 005/289] server: fix params leak between requests (#3784) --- examples/server/server.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 08c0988d2be..c582c448de1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -824,7 +824,7 @@ int main(int argc, char ** argv) { } auto audio_file = req.get_file_value("file"); - // check non-required fields + whisper_params params = default_params; get_req_parameters(req, params); std::string filename{audio_file.filename}; @@ -1127,9 +1127,6 @@ int main(int argc, char ** argv) { res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); } - - // reset params to their defaults - params = default_params; }); svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ std::lock_guard lock(whisper_mutex); From 3e9b7d0fef3528ee2208da3cdb873a2c53d2ae2f Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Wed, 13 May 2026 10:37:28 +0200 Subject: [PATCH 006/289] server : fix no_speech_thold not being read (#3783) --- examples/server/server.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c582c448de1..735255b6290 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -87,7 +87,7 @@ struct whisper_params { float logprob_thold = -1.00f; float temperature = 0.00f; float temperature_inc = 0.20f; - float no_speech_thold = 0.6f; + float no_speech_thold = 0.6f; bool debug_mode = false; bool translate = false; @@ -527,6 +527,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content); } + if (req.has_file("no_speech_thold")) + { + params.no_speech_thold = std::stof(req.get_file_value("no_speech_thold").content); + } if (req.has_file("debug_mode")) { params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content); @@ -762,6 +766,7 @@ int main(int argc, char ** argv) { -F file="@<file-path>" \ -F temperature="0.0" \ -F temperature_inc="0.2" \ + -F no_speech_thold="0.6" \ -F response_format="json" @@ -940,7 +945,7 @@ int main(int argc, char ** argv) { wparams.beam_search.beam_size = params.beam_size; wparams.temperature = params.temperature; - wparams.no_speech_thold = params.no_speech_thold; + wparams.no_speech_thold = params.no_speech_thold; wparams.temperature_inc = params.temperature_inc; wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; From ff5704a416813610e30e54d864c5af1be41288c6 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Fri, 1 May 2026 23:02:24 -0700 Subject: [PATCH 007/289] opencl: Adreno optimization for MoE - MxFP4 (llama/22301) * MoE Mxfp4 CLC kernel added, router reorder on GPU * Pass test-backend-ops for MoE mxfp4 Adreno CLC * remove putenv in llama-model.cpp * fix indent style and whitespace * opencl: remove unnecessary headers * opencl: do not save cl_program objects * opencl: remove unnecessary assert * fix precision issue --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 451 +++++++++++++++--- ggml/src/ggml-opencl/kernels/cvt.cl | 87 ++++ .../kernels/gemm_moe_mxfp4_f32_ns.cl | 302 ++++++++++++ .../kernels/gemv_moe_mxfp4_f32_ns.cl | 161 +++++++ ggml/src/ggml-opencl/kernels/moe_reorder_b.cl | 30 ++ .../ggml-opencl/kernels/moe_sort_by_expert.cl | 82 ++++ 7 files changed, 1040 insertions(+), 77 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/moe_reorder_b.cl create mode 100644 ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 5ed83eeb48a..35d425a431f 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -107,6 +107,10 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 + gemm_moe_mxfp4_f32_ns + gemv_moe_mxfp4_f32_ns + moe_reorder_b + moe_sort_by_expert mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q4_0_f32_l4_lm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 11f72a5198a..74948c27e4e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -416,6 +416,15 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_src0; ggml_cl_buffer prealloc_src1; + // prealloc buffers for MoE router table preprocess + bool toggle_reorder = false; + ggml_cl_buffer prealloc_post_router; + ggml_cl_buffer prealloc_emap; + ggml_cl_buffer prealloc_hist; + ggml_cl_buffer prealloc_tile_offset; + ggml_cl_buffer prealloc_total_tiles; + ggml_cl_buffer prealloc_slot_counter; + cl_program program_add; cl_program program_add_id; cl_program program_clamp; @@ -531,6 +540,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; + cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -587,6 +597,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; + cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; + cl_kernel kernel_moe_reorder_b; + cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat; cl_kernel kernel_mul_mv_id_mxfp4_f32; @@ -945,6 +958,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); @@ -2864,6 +2879,77 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_reorder_b + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_reorder_b.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_reorder_b.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_sort_by_expert + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_sort_by_expert.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_sort_by_expert.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_noshuffle_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3651,13 +3737,12 @@ struct ggml_tensor_extra_cl_mxfp4 { CL_CHECK(clReleaseMemObject(e)); e = nullptr; } - if (q != nullptr) { + if (q_img != nullptr) { CL_CHECK(clReleaseMemObject(q_img)); - q = nullptr; + q_img = nullptr; } - // Currently, q_img and d_img are not used. They can be image1d_buffer_t + // Currently, e_img is not used. They can be image1d_buffer_t // that wraps around q and d to utilize image access path. - q_img = nullptr; e_img = nullptr; size_q = 0; size_e = 0; @@ -4740,7 +4825,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { GGML_UNUSED(backend_ctx); int ne01 = tensor->ne[1]; - return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); } inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { @@ -5151,8 +5236,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(err); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe mxfp4 kernel needs special transpose and unshuffling if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans; + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; @@ -5172,9 +5258,21 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + return; } -#endif + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); @@ -5912,7 +6010,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans; + cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; @@ -5936,7 +6034,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } -#endif + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); @@ -12763,6 +12862,118 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } } +static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, int ne20) { + cl_int err; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra; + cl_ulong offset = extra->offset + src->view_offs; + + const int ne21 = src->ne[1]; + const int nb21 = src->nb[1]; + const int ne02 = nb21 / src->nb[0]; + const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + + cl_buffer_region region; + region.origin = offset; + region.size = nb21 * ne21; + cl_mem original_router_buf = clCreateSubBuffer(extra->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_post_router.allocate(backend_ctx->context, sizeof(int) * max_post_router_tile * n_tile_size); + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + cl_mem post_router_buf = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_emap.allocate(backend_ctx->context, sizeof(short) * max_post_router_tile); + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + cl_mem emap_buf = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_hist.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem hist_buf = clCreateSubBuffer(backend_ctx->prealloc_hist.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_tile_offset.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem tile_offset_buf = clCreateSubBuffer(backend_ctx->prealloc_tile_offset.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_slot_counter.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem slot_counter_buf = clCreateSubBuffer(backend_ctx->prealloc_slot_counter.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_total_tiles.allocate(backend_ctx->context, sizeof(int)); + region.origin = 0; + region.size = sizeof(int); + cl_mem total_tiles_buf = clCreateSubBuffer(backend_ctx->prealloc_total_tiles.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + // Histogram + cl_kernel kernel = backend_ctx->kernel_moe_histogram; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02)); + + size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast(ne20), 1}; + size_t histogram_local_size[] = {64, static_cast(ne20), 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + // Scan + kernel = backend_ctx->kernel_moe_scan; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n_tile_size)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02)); + + size_t scan_global_size[] = {1}; + size_t scan_local_size[] = {1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, scan_global_size, scan_local_size, src); + + // Fill + kernel = backend_ctx->kernel_moe_fill; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &n_tile_size)); + + size_t fill_global_size[] = {(size_t)(((max_post_router_tile + 63) / 64) * 64), n_tile_size, 1}; + size_t fill_local_size[] = {64, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, fill_global_size, fill_local_size, src); + + // Scatter + kernel = backend_ctx->kernel_moe_scatter; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &emap_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + CL_CHECK(clReleaseMemObject(original_router_buf)); + CL_CHECK(clReleaseMemObject(hist_buf)); + CL_CHECK(clReleaseMemObject(tile_offset_buf)); + CL_CHECK(clReleaseMemObject(total_tiles_buf)); + CL_CHECK(clReleaseMemObject(slot_counter_buf)); + CL_CHECK(clReleaseMemObject(post_router_buf)); + CL_CHECK(clReleaseMemObject(emap_buf)); +} + static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12824,6 +13035,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int ne0 = dst->ne[0]; const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; const int r2 = ne12/ne02; const int r3 = ne13/ne03; @@ -12836,6 +13048,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, int nrows = 1; // number of row in src1 int ndst = 4; // number of values produced by each subgroup + const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + cl_kernel kernel; // subgroup mat vec @@ -12967,11 +13182,10 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, size_t local_size[3] = {64, 2, 1}; size_t global_size[3] = {64, 2, 1}; - cl_mem src1_sub_buffer, buf_src1_image, buf_src2; - - int tile_size = 320; if (ne12 == 1) { // for gemv - kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32; + kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; // create a sub_buffer for src2 cl_buffer_region region; @@ -12985,78 +13199,154 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + } else { // for gemm - kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32; - - // preprocess router table - int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size; - void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short)); - void * host_src2 = malloc(ne21 * nb21); - CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL)); - int total_experts = nb21 / nb20; - int out_idx = 0; - for (int i_expert = 0; i_expert < ne02; i_expert++) { - for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) { - for (int j = 0; j < ne21; j++) { - for (int i = 0; i < ne20; i++) { - int expert = ((int *)host_src2)[j * total_experts + i]; - if (i_expert == expert) { - ((short *)host_src2_reorder)[out_idx] = static_cast(expert); - ((short *)host_src2_reorder)[out_idx + 1] = static_cast(j * ne11 + (i % ne11)); - ((short *)host_src2_reorder)[out_idx + 2] = static_cast(j * ne20 + i); - ((short *)host_src2_reorder)[out_idx + 3] = static_cast(i_tile); - out_idx += 4; - } - } - } - } + kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; } - buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status); + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + GGML_ASSERT(backend_ctx->prealloc_post_router.buffer); + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); CL_CHECK(status); - // set thread grid - global_size[0] = static_cast(tile_size); - global_size[2] = static_cast(ne20 * ne21 * num_tiles_per_expert); - } + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - // create a sub_buffer for src1 - cl_buffer_region region; - region.origin = offset1; - region.size = ne10 * ne11 * ne12 * sizeof(float); - src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // create image for src1 - cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; - cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; - buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); - CL_CHECK(status); - - // Set kernel args - int arg_idx = 0; - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); - if (ne12 == 1) { - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); - } else { - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size)); - } + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); - // launch kernel - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); - // deallocate sub buffers and images - CL_CHECK(clReleaseMemObject(src1_sub_buffer)); - CL_CHECK(clReleaseMemObject(buf_src1_image)); - CL_CHECK(clReleaseMemObject(buf_src2)); + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } return; - } // else fallback to generic kernel + } // fallback to generic MoE mxfp4 kernel #endif // GGML_OPENCL_USE_ADRENO_KERNELS #ifdef GGML_OPENCL_SOA_Q @@ -14002,6 +14292,13 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co size_t local_work_size[] = {(size_t)ne00_padded, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + const int ne21 = dst->ne[1]; + if ((strstr(src0->name, "_moe") != NULL) && (ne21 != 1)) { + backend_ctx->toggle_reorder = true; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS } static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index f3937d8304c..c1ad46f4435 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -371,6 +371,93 @@ kernel void kernel_restore_block_mxfp4_trans( b->e = src_e[src_blk_offset]; } +kernel void kernel_convert_block_mxfp4_trans4_ns( + global struct block_mxfp4 * src0, + __global uint * dst_q, + __global uchar * dst_e, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = src0 + src_blk_offset; + dst_e[dst_blk_offset] = b->e; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_mxfp4_trans4_ns( + __global uint * src_q, + __global uchar * src_e, + __global struct block_mxfp4 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_mxfp4 * b = dst0 + dst_blk_offset; + b->e = src_e[src_d_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK_MXFP4 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + + //------------------------------------------------------------------------------ // block_q8_0 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..e404f392bdd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -0,0 +1,302 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +static inline half e8m0_to_fp16(uchar x) { + ushort bits; + bits = (ushort)(x) - (ushort)(112); + bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10); + return as_half(bits); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_mxfp4_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uchar * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current mxfp4 block + uint s_offset = s_sub_offset + get_global_id(0); + float s = e8m0_to_fp32(src0_d[s_offset]); + + // Load 16 fp4 (64-bits) in transposed layout + uint2 mxfp4x16; + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 fp4 (64-bits) in transposed layout + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..e4b44c1a56a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -0,0 +1,161 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_mxfp4_f32_ns( + __global uint * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl new file mode 100644 index 00000000000..e6295c81648 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define QK4_0 32 + +kernel void kernel_moe_reorder_b( + global float4 * src, + global uint * router, + global float4 * dst, + global int * total_tiles, + uint K, + ushort map_ratio, + uint tile_size +) { + uint k_4 = get_global_id(0); + uint post_router_idx = get_global_id(1); + + if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) { + return; + } + + uint router_idx = router[post_router_idx]; + + float4 out = (float4)(0); + if (router_idx != 0xFFFFFFFF) { + ushort activation_idx = router_idx / map_ratio; + out = src[activation_idx * K / 4 + k_4]; + } + + dst[post_router_idx * K / 4 + k_4] = out; +} diff --git a/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl new file mode 100644 index 00000000000..d9703429b11 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl @@ -0,0 +1,82 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void kernel_moe_histogram( + __global const int * input, + __global int * hist, + uint N, + uint topK, + uint n_experts +) { + uint n = get_global_id(0); + uint k = get_global_id(1); + + if (n >= N || k >= topK) { + return; + } + + int expert_id = input[n * n_experts + k]; + atomic_inc(&hist[expert_id]); +} + +__kernel void kernel_moe_scan( + __global int * hist, + __global int * tile_offset, + __global int * total_tiles, + __global int * slot_counter, + int tile_size, + uint n_experts +) { + int offset = 0; + for (int v = 0; v < n_experts; v++) { + int count = hist[v]; + int tiles = (count + tile_size - 1) / tile_size; + tile_offset[v] = offset; + offset += tiles; + hist[v] = 0; + slot_counter[v] = 0; + } + + *total_tiles = offset; +} + +__kernel void kernel_moe_scatter( + __global const int * input, + __global int * post_router, + __global ushort * emap, + __global const int * tile_offset, + __global int * slot_counter, + int N, + int topK, + uint n_experts +) { + uint n = get_global_id(0); + uint k = get_global_id(1); + + if (n >= N || k >= topK) { + return; + } + + int val = input[n * n_experts + k]; + + int local_slot = atomic_inc(&slot_counter[val]); + + int tile_idx = tile_offset[val] + (local_slot / 32); + int lane = local_slot % 32; + int out_pos = tile_idx * 32 + lane; + + post_router[out_pos] = n * topK + k; + emap[tile_idx] = val; +} + +__kernel void kernel_moe_fill( + __global int * post_router, + __global int * total_tiles, + int tile_size +) { + int tile_id = get_global_id(0); + int vec_id_in_tile = get_global_id(1); + + if (tile_id < total_tiles[0]) { + post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF; + } +} From 9ab94b8cdac1845f492421321aa14cebdc705853 Mon Sep 17 00:00:00 2001 From: JusteLeo Date: Sat, 2 May 2026 15:28:50 +0200 Subject: [PATCH 008/289] ggml-virtgpu: fix circular dependency in headers (llama/22557) --- ggml/src/ggml-virtgpu/virtgpu-shm.cpp | 1 + ggml/src/ggml-virtgpu/virtgpu.cpp | 1 + ggml/src/ggml-virtgpu/virtgpu.h | 2 -- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp index ce6b3b3e607..7f2c2322d91 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp @@ -1,6 +1,7 @@ #include "virtgpu-shm.h" #include "virtgpu.h" +#include "ggml-remoting.h" #include diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp index a84a77399d9..e3ae1cc75e0 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -1,4 +1,5 @@ #include "virtgpu.h" +#include "ggml-remoting.h" #include #include diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h index f82d8fb50ba..6b8de583893 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.h +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -18,8 +18,6 @@ #include -#include "ggml-remoting.h" - #define VIRGL_RENDERER_UNSTABLE_APIS 1 #include "apir_hw.h" #include From 3bcac0a0c7d73128fa9cf6e65ff1d16ff8438933 Mon Sep 17 00:00:00 2001 From: lucy <154630366+lucyknada@users.noreply.github.com> Date: Sat, 2 May 2026 16:19:25 -0400 Subject: [PATCH 009/289] fix: CUDA device PCI bus ID de-dupe OOMing (ignoring other 3 gpus entirely) (llama/22533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: CUDA device PCI bus ID detection for multi-GPU de-dupe * HIP, MUSA macros --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fbe0fa06242..8d21b2267f5 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5431,8 +5431,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; - char pci_bus_id[16] = {}; - snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + char pci_bus_id[32] = {}; + CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i)); dev_ctx->pci_bus_id = pci_bus_id; dev_ctx->op_offload_min_batch_size = min_batch_size; diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 78ca364d38f..e5d363c65d1 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -55,6 +55,7 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 8aa056e9174..940c34a9fb2 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -39,6 +39,7 @@ #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t From d1d0dc2348f6b294598725ef8cd3d40652fb674d Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Sun, 3 May 2026 23:52:53 -0400 Subject: [PATCH 010/289] ggml-webgpu: add layer norm ops (llama/22406) * shader(norm): add layer norm ops * shader(norm): stablize floating point computation with Kahan summation and handle mixed types * shader(norm): remove the non-contiguous strides * shader(norm): use the original implementation rather than the kahan summation --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 32 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 + .../ggml-webgpu/wgsl-shaders/row_norm.wgsl | 97 +++++++++++++++---- 3 files changed, 107 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index cff93b8d170..c6dc2c21147 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -228,11 +228,13 @@ struct ggml_webgpu_get_rows_pipeline_key_hash { /** Row Norm **/ struct ggml_webgpu_row_norm_pipeline_key { - ggml_op op; - bool inplace; + ggml_op op; + ggml_type src_type; + ggml_type dst_type; + bool inplace; bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { - return op == other.op && inplace == other.inplace; + return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace; } }; @@ -240,6 +242,8 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.inplace); return seed; } @@ -1097,6 +1101,8 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_row_norm_pipeline_key key = {}; key.op = context.dst->op; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = row_norm_pipelines.find(key); @@ -1111,6 +1117,10 @@ class ggml_webgpu_shader_lib { defines.push_back("RMS_NORM"); variant = "rms_norm"; break; + case GGML_OP_NORM: + defines.push_back("NORM"); + variant = "norm"; + break; case GGML_OP_L2_NORM: defines.push_back("L2_NORM"); variant = "l2_norm"; @@ -1124,6 +1134,22 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } + if (key.src_type == GGML_TYPE_F32) { + defines.push_back("SRC_F32"); + variant += "_src_f32"; + } else if (key.src_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } + + if (key.dst_type == GGML_TYPE_F32) { + defines.push_back("DST_F32"); + variant += "_dst_f32"; + } else if (key.dst_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } + const uint32_t row_norm_wg_size = 128u; uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index cab0aead198..12f60a9900e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2927,6 +2927,7 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, } else { return ggml_webgpu_row_norm(ctx, src0, node); } + case GGML_OP_NORM: case GGML_OP_L2_NORM: return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -4071,6 +4072,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: + case GGML_OP_NORM: case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index bd8d32bded7..5eaf5e7bbe5 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -1,20 +1,17 @@ -#ifdef INPLACE -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; -} +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif -@group(0) @binding(1) -var params: Params; +#ifdef SRC_F16 +#define SRC_TYPE f16 #else -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - dst[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var dst: array; +#define SRC_TYPE f32 +#endif -@group(0) @binding(2) -var params: Params; +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 #endif struct Params { @@ -40,9 +37,20 @@ struct Params { }; @group(0) @binding(0) -var src: array; +var src: array; -var scratch: array; +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; +#else +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#endif + +var scratch: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3, @@ -65,34 +73,81 @@ fn main(@builtin(workgroup_id) wid: vec3, if (col >= params.ne0) { break; } - sum += pow(src[i_src_row + col], 2.0); + let v = f32(src[i_src_row + col]); +#ifdef NORM + sum += v; +#else + sum += v * v; +#endif col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - var offset: u32 = WG_SIZE / 2; + + var offset: u32 = WG_SIZE / 2u; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; } - offset = offset / 2; + offset /= 2u; workgroupBarrier(); } sum = scratch[0]; -#ifdef RMS_NORM +#ifdef NORM + let mean = sum / f32(params.ne0); + var sq_sum = 0.0f; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let v = f32(src[i_src_row + col]); + let d = v - mean; + sq_sum += d * d; + col += WG_SIZE; + } + + workgroupBarrier(); + scratch[lid.x] = sq_sum; + workgroupBarrier(); + offset = WG_SIZE / 2u; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset /= 2u; + workgroupBarrier(); + } + + let variance = scratch[0] / f32(params.ne0); + let scale = 1.0 / sqrt(variance + params.eps); +#elif defined(RMS_NORM) let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); #elif defined(L2_NORM) let scale = 1.0/max(sqrt(sum), params.eps); #endif +#ifdef NORM + let mean_val = mean; +#else + let mean_val = 0.0f; +#endif + col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { break; } - update(i_src_row + col, i_dst_row + col, scale); + let i_src = i_src_row + col; + let i_dst = i_dst_row + col; + let v = src[i_src]; +#ifdef INPLACE + src[i_dst] = scale * (v - mean_val); +#else + dst[i_dst] = scale * (v - mean_val); +#endif col += WG_SIZE; } } From 0fffe2cdb87a4bf0b91dde4068858f6b7ef0838a Mon Sep 17 00:00:00 2001 From: Atomic-Germ <97569476+Atomic-Germ@users.noreply.github.com> Date: Sun, 3 May 2026 22:49:29 -0700 Subject: [PATCH 011/289] vulkan: delete dead GGML_VK_MAX_NODES def (llama/22621) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c2f1883328f..423e01dbff1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -111,8 +111,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 -#define GGML_VK_MAX_NODES 8192 - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ From 36a83b84bb29651e9802b8d287d3941240fc860b Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Mon, 4 May 2026 22:24:05 +0800 Subject: [PATCH 012/289] CUDA: use fastdiv for batch index split in get_rows (llama/22650) --- ggml/src/ggml-cuda/getrows.cu | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index e99cba63d34..36b840e8148 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -6,17 +6,18 @@ template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. const int i10 = blockIdx.x; - const int i11 = z / ne12; // TODO fastdiv - const int i12 = z % ne12; + const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv); + const int i11 = dm.x; + const int i12 = dm.y; const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; @@ -42,17 +43,18 @@ template static __global__ void k_get_rows_float( const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. const int i10 = blockIdx.x; - const int i11 = z / ne12; // TODO fastdiv - const int i12 = z % ne12; + const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv); + const int i11 = dm.x; + const int i12 = dm.y; if (i00 >= ne00) { return; @@ -115,10 +117,14 @@ static void get_rows_cuda_q( GGML_ASSERT(ne00 % 2 == 0); + GGML_ASSERT(ne12 > 0); + GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); + const uint3 ne12_fdv = init_fastdiv_values(ne12); + k_get_rows<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10,*/ ne11, ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -146,10 +152,14 @@ static void get_rows_cuda_float( const size_t s12 = nb12 / sizeof(int32_t); // const size_t s13 = nb13 / sizeof(int32_t); + GGML_ASSERT(ne12 > 0); + GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); + const uint3 ne12_fdv = init_fastdiv_values(ne12); + k_get_rows_float<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10,*/ ne11, ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); From 254f951db8ffd75512c87dc3a234a4b84bf8c6ad Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Mon, 4 May 2026 21:13:31 +0200 Subject: [PATCH 013/289] kleidiai : update to v1.24.0 and use release archive (llama/22549) --- ggml/src/ggml-cpu/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index c1c225f0197..869c7b238bf 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -578,13 +578,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.22.0") - set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") + set(KLEIDIAI_COMMIT_TAG "v1.24.0") + set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/releases/download/${KLEIDIAI_COMMIT_TAG}/kleidiai-${KLEIDIAI_COMMIT_TAG}-src.tar.gz") + set(KLEIDIAI_RELEASE_ARCHIVE_MD5 "2f02ebe29573d45813e671eb304f2a00") set(KLEIDIAI_FETCH_ARGS URL ${KLEIDIAI_DOWNLOAD_URL} - URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5} + URL_HASH MD5=${KLEIDIAI_RELEASE_ARCHIVE_MD5} ) if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW) From 4794432337769d04ffb7d443747d69e6c0fd7469 Mon Sep 17 00:00:00 2001 From: Ismail <115064057+AlrIsmail@users.noreply.github.com> Date: Tue, 5 May 2026 04:05:05 +0200 Subject: [PATCH 014/289] ggml : implement fast walsh-hadamard transform for kv rotation (#21352) (llama/22631) --- ggml/include/ggml.h | 11 +++++ ggml/src/ggml-cpu/ggml-cpu.c | 6 +++ ggml/src/ggml-cpu/ops.cpp | 88 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml.c | 10 ++++ 5 files changed, 116 insertions(+) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 703e3783136..3357a0d9985 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -438,6 +438,12 @@ extern "C" { GGML_PREC_F32 = 10, }; + // op hint + enum ggml_op_hint { + GGML_HINT_NONE = 0, + GGML_HINT_SRC0_IS_HADAMARD = 1, + }; + // model file types enum ggml_ftype { GGML_FTYPE_UNKNOWN = -1, @@ -1419,6 +1425,11 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + // change the hint of a matrix multiplication + GGML_API void ggml_mul_mat_set_hint( + struct ggml_tensor * a, + enum ggml_op_hint hint); + // indirect matrix multiplication GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2b3eb5b5ce6..2d6cc1fcd46 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1245,6 +1245,12 @@ void ggml_compute_forward_mul_mat( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const int32_t hint = ggml_get_op_params_i32(dst, 1); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !params->use_ref) { + ggml_compute_forward_fwht(params, dst); + return; + } + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a9bc21da6f0..211f1ba1b2f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -11212,3 +11212,91 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_ } } } + +static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t n = ne10; + GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2 + + const int64_t nr = ne11 * ne12 * ne13; + const int64_t rows_per_thread = (nr + nth - 1) / nth; + const int64_t start_row = ith * rows_per_thread; + const int64_t end_row = MIN(start_row + rows_per_thread, nr); + + const float scale = 1.0f / sqrtf((float)n); + +#if defined(GGML_SIMD) + const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f); +#endif + + for (int64_t r = start_row; r < end_row; r++) { + const int64_t i13 = r / (ne11 * ne12); + const int64_t i12 = (r - i13 * ne11 * ne12) / ne11; + const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11; + + const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13); + float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3); + + for (int64_t j = 0; j < n; j++) { + dst_row[j] = src_row[j] * scale; + } + + // Scalar passes +#if defined(GGML_SIMD) + const int step = GGML_F32_EPR; +#else + const int step = n; +#endif + for (int64_t len = 1; len < step && len < n; len <<= 1) { + for (int64_t i = 0; i < n; i += 2 * len) { + for (int64_t j = 0; j < len; j++) { + float u = dst_row[i + j]; + float v = dst_row[i + len + j]; + dst_row[i + j] = u + v; + dst_row[i + len + j] = u - v; + } + } + } + + // SIMD passes using GGML_F32_VEC_* macros for multi-architecture support +#if defined(GGML_SIMD) + for (int64_t len = step; len < n; len <<= 1) { + for (int64_t i = 0; i < n; i += 2 * len) { + for (int64_t j = 0; j < len; j += step) { + GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j); + GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j); + + GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v)); + GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one)); + } + } + } +#endif + } +} + +void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src1 = dst->src[1]; + + switch (src1->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fwht_f32(params, dst); + } + break; + default: + { + GGML_ABORT("fatal error - fwht is F32 only"); + } + } +} diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 3fa1443abc4..29efdeee37f 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -111,6 +111,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_fwht(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 81343eeb14c..191cf2fa106 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3264,6 +3264,16 @@ void ggml_mul_mat_set_prec( ggml_set_op_params_i32(a, 0, prec_i32); } +void ggml_mul_mat_set_hint( + struct ggml_tensor * a, + enum ggml_op_hint hint) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + + const int32_t hint_i32 = (int32_t) hint; + + ggml_set_op_params_i32(a, 1, hint_i32); +} + // ggml_mul_mat_id /* From 6f6103f6d0034945a9377d16e29cf0d3ec2b4c35 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 May 2026 06:35:07 +0300 Subject: [PATCH 015/289] llama : add option to save memory in device buffers (llama/22679) * llama : add option to save memory in device buffers * tests : extend llama-save-load-state --- ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 42 +++++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal.cpp | 19 ++++++----- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index a6c1dab5515..4718ca083b0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -282,6 +282,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf); void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +bool ggml_metal_buffer_cpy_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * src, struct ggml_tensor * dst); void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value); // finds the Metal buffer that contains the tensor data on the GPU device diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fe90aafe7bc..fab7891c008 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1,6 +1,7 @@ #import "ggml-metal-device.h" #import "ggml-impl.h" +#import "ggml-backend-impl.h" #include @@ -1737,6 +1738,47 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten } } +bool ggml_metal_buffer_cpy_tensor(ggml_metal_buffer_t buf_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_metal_buffer_t buf_src = (ggml_metal_buffer_t)src->buffer->context; + + const size_t size = ggml_nbytes(src); + + // if both buffers are shared, we can use memcpy directly + if (buf_dst->is_shared && buf_src->is_shared) { + memcpy(dst->data, src->data, size); + return true; + } + + // for private buffers, we need to use Metal blit commands + @autoreleasepool { + struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf_src, src); + struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf_dst, dst); + + if (bid_src.metal == nil || bid_dst.metal == nil) { + return false; + } + + id cmd_buf = [buf_dst->dev->mtl_queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } + + return true; +} + void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) { if (buf->is_shared) { memset(buf->all_data, value, buf->all_size); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index cc329d67594..35774254983 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -17,6 +17,9 @@ // note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices static int g_devices = 1; +// forward declaration +static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer); + //////////////////////////////////////////////////////////////////////////////// // backend interface //////////////////////////////////////////////////////////////////////////////// @@ -68,11 +71,11 @@ static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t bu GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); - GGML_UNUSED(buffer); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer)) { + return false; + } - return false; + return ggml_metal_buffer_cpy_tensor(ctx, src, dst); } static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -144,11 +147,11 @@ static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t b GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); - GGML_UNUSED(buffer); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer)) { + return false; + } - return false; + return ggml_metal_buffer_cpy_tensor(ctx, src, dst); } static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { From 716acdb08212087ca61b56f569e354068e6613eb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 May 2026 13:14:32 +0300 Subject: [PATCH 016/289] ggml : bump version to 0.11.0 (ggml/1478) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index c97f681988b..8dd4d64063f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 10) -set(GGML_VERSION_PATCH 2) +set(GGML_VERSION_MINOR 11) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 0bafd810b60a76bdb9e9784759cd22511d779c40 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 5 May 2026 13:47:13 +0300 Subject: [PATCH 017/289] rpc : use graph uid instead of graph cache (llama/22701) Store the last graph uid and compare against it to determine if the same graph is being computed. --- ggml/src/ggml-rpc/ggml-rpc.cpp | 38 +++++++--------------------------- 1 file changed, 7 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 7176d2feef9..1cb8f563d85 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -207,35 +207,11 @@ struct ggml_backend_rpc_buffer_type_context { size_t max_size; }; -struct graph_cache { - - bool is_cached(const ggml_cgraph * cgraph) { - if ((int)last_graph.size() != cgraph->n_nodes) { - return false; - } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { - return false; - } - } - return true; - } - - void add(const ggml_cgraph * cgraph) { - last_graph.resize(cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)); - } - } - - std::vector last_graph; -}; - struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - graph_cache gc; + uint64_t last_graph_uid; }; struct ggml_backend_rpc_buffer_context { @@ -717,7 +693,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = rpc_ctx->gc.is_cached(cgraph); + bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -725,7 +701,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->gc.add(cgraph); + rpc_ctx->last_graph_uid = cgraph->uid; std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -791,10 +767,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .device = */ device, - /* .name = */ dev_name, - /* .gc = */ {}, + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name, + /* .last_graph_uid = */ 0, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { From f83b6bdc44c5e44607cae112cfa883686cd57271 Mon Sep 17 00:00:00 2001 From: lhez Date: Sun, 10 May 2026 14:52:20 +0300 Subject: [PATCH 018/289] opencl: refactor Adreno q4_0 (llama/22335) --- ggml/src/ggml-opencl/CMakeLists.txt | 10 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 944 +++++++----------- ...b_Bi_8x4.cl => gemm_noshuffle_q4_0_f32.cl} | 2 +- ..._f32_8x4.cl => gemm_noshuffle_q8_0_f32.cl} | 2 +- ..._general.cl => gemv_noshuffle_q4_0_f32.cl} | 10 +- ...fle.cl => gemv_noshuffle_q4_0_f32_spec.cl} | 10 +- ...q8_0_f32.cl => gemv_noshuffle_q8_0_f32.cl} | 0 7 files changed, 355 insertions(+), 623 deletions(-) rename ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl => gemm_noshuffle_q4_0_f32.cl} (99%) rename ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl => gemm_noshuffle_q8_0_f32.cl} (98%) rename ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl => gemv_noshuffle_q4_0_f32.cl} (98%) rename ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl => gemv_noshuffle_q4_0_f32_spec.cl} (98%) rename ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl => gemv_noshuffle_q8_0_f32.cl} (100%) diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 35d425a431f..0a45a4daa13 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -66,8 +66,6 @@ set(GGML_OPENCL_KERNELS diag div gelu - gemv_noshuffle_general - gemv_noshuffle get_rows glu group_norm @@ -75,7 +73,6 @@ set(GGML_OPENCL_KERNELS im2col_f32 im2col_f16 mean - mul_mat_Ab_Bi_8x4 mul_mv_f16_f16 mul_mv_f16_f32_1row mul_mv_f16_f32_l4 @@ -120,12 +117,15 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_k_f32_l4_lm mul_mm_q5_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm - mul_mm_q8_0_f32_8x4 + gemv_noshuffle_q4_0_f32 + gemv_noshuffle_q4_0_f32_spec + gemm_noshuffle_q4_0_f32 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 gemv_noshuffle_iq4_nl_f32 gemm_noshuffle_iq4_nl_f32 - gemv_noshuffle_general_q8_0_f32 + gemv_noshuffle_q8_0_f32 + gemm_noshuffle_q8_0_f32 gemv_noshuffle_q4_k_f32 gemm_noshuffle_q4_k_f32 gemv_noshuffle_q6_k_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 74948c27e4e..8c7bf98c16f 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -731,22 +731,16 @@ struct ggml_backend_opencl_context { cl_kernel kernel_transpose_16_4x1; // Gemm and Gemv related programs, kernels, etc - cl_program program_CL_gemm; - cl_program program_CL_gemv_general; - cl_program program_CL_gemv_4096_1_11008; - cl_program program_CL_gemv_4096_1_4096; - cl_program program_CL_gemv_11008_1_4096; - cl_program program_CL_gemv_32000_1_4096; - cl_kernel CL_mul_mat_Ab_Bi_8x4; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + cl_kernel kernel_gemm_noshuffle_q4_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_4096_1_11008; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_4096_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_11008_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_32000_1_4096; cl_kernel kernel_gemv_noshuffle_q4_1_f32; cl_kernel kernel_gemm_noshuffle_q4_1_f32; - cl_kernel kernel_mul_mm_q8_0_f32_8x4; - cl_kernel CL_mul_mat_vec_q8_0_f32; + cl_kernel kernel_gemm_noshuffle_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q8_0_f32; cl_kernel kernel_gemv_noshuffle_q4_k_f32; cl_kernel kernel_gemm_noshuffle_q4_k_f32; cl_kernel kernel_gemv_noshuffle_q6_K_f32; @@ -2578,21 +2572,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemv_general { - #include "gemv_noshuffle_general.cl.h" + #include "gemv_noshuffle_q4_0_f32.cl.h" }; #else - const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_q4_0_f32.cl"); #endif - backend_ctx->program_CL_gemv_general = build_program_from_source( + cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2606,20 +2601,21 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemv { - #include "gemv_noshuffle.cl.h" + #include "gemv_noshuffle_q4_0_f32_spec.cl.h" }; #else - const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle_q4_0_f32_spec.cl"); #endif - backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); // Gemv 2048, 16384 @@ -2630,12 +2626,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_11008 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); // Gemv 5504, 44032 @@ -2646,12 +2643,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_11008_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); // Gemv 16000, 128000 @@ -2663,12 +2661,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( + prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_32000_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2676,13 +2675,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemm { - #include "mul_mat_Ab_Bi_8x4.cl.h" + #include "gemm_noshuffle_q4_0_f32.cl.h" }; #else - const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); + const std::string kernel_src_CL_gemm = read_file("gemm_noshuffle_q4_0_f32.cl"); #endif - backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2767,14 +2767,15 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve // mul_mm_q8_0_f32_8x4 { #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_q8_8x4_gemm { - #include "mul_mm_q8_0_f32_8x4.cl.h" + const std::string kernel_src { + #include "gemm_noshuffle_q8_0_f32.cl.h" }; #else - const std::string kernel_src_q8_8x4_gemm = read_file("mul_mm_q8_0_f32_8x4.cl"); + const std::string kernel_src = read_file("gemm_noshuffle_q8_0_f32.cl"); #endif - backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q8_8x4_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mm_q8_0_f32_8x4", &err), err)); + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q8_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q8_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2790,16 +2791,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemv_general { - #include "gemv_noshuffle_general_q8_0_f32.cl.h" + #include "gemv_noshuffle_q8_0_f32.cl.h" }; #else - const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general_q8_0_f32.cl"); + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_q8_0_f32.cl"); #endif cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -4937,164 +4938,15 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image if (use_adreno_kernels(backend_ctx, tensor)) { - // <----------------------------------------------------------------------------------> // - // start transpose - // <----------------------------------------------------------------------------------> // - int M = tensor->ne[1]; // ne01 - int K = tensor->ne[0]; // ne00 - - //For matrix-vector multiplication kernel, we assume K is a multiple of 32 - GGML_ASSERT(K % 32 == 0); - //For transpose kernels, we assume K is a multiple of 4 (satisfied by prior assert), and M is a multiple of 4 - GGML_ASSERT(M % 4 == 0); - - // transpose is out of place, so we need to allocate transposed buffers - // <----------------------------------------------------------------------------------> // - // use sub_buffer of max buffer size instead - - size_t q_size_bytes = K * M / 8 * sizeof(float); - backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes); - - cl_buffer_region region; - region.origin = 0; - region.size = q_size_bytes; - cl_mem qT_d = clCreateSubBuffer( - backend_ctx->prealloc_quant_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - bool K_tile_trans = true; - if ((K / 32) % 4 != 0){ - K_tile_trans =false; - } - - size_t d_size_bytes = M * (K / 32) * 2; - backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes); - - region.origin = 0; - region.size = d_size_bytes; - cl_mem dT_d = clCreateSubBuffer( - backend_ctx->prealloc_scales_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - // <----------------------------------------------------------------------------------> // - - - // create images from the buffers - // <----------------------------------------------------------------------------------> // - cl_mem q_d_image1D; - cl_mem d_d_image1D; - cl_mem qT_d_image1D; - cl_mem dT_d_image1D; - - cl_image_format img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - cl_image_desc img_desc_1d; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = extra->q; - q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = qT_d; - qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - if (K_tile_trans) { - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32 / 4; - } else { - img_fmt_1d = { CL_R, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32; - } - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = extra->d; - d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4; - img_desc_1d.buffer = dT_d; - dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - // <----------------------------------------------------------------------------------> // - - // set up and call the transpose kernels - // <----------------------------------------------------------------------------------> // - // weights - int height_q = M / 4; - int width_q = K / 4 / 4; - kernel = backend_ctx->kernel_transpose_16; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); - - size_t local_size_q[3] = {4, 16, 1}; - size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // scales - int height_s = M / 4; - int width_s = K / 32 / 4; - - kernel = backend_ctx->kernel_transpose_16; - if (!K_tile_trans) { - kernel = backend_ctx->kernel_transpose_16_4x1; - width_s = K / 32; - } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); - - size_t local_size_s[3] = {4, 16, 1}; - size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // + int M = tensor->ne[1]; + int K = tensor->ne[0]; - // copy transposed buffer contents to original buffers - // <----------------------------------------------------------------------------------> // - // weights - CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + GGML_ASSERT(K % 32 == 0); - // scales - CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // - - // deallocate transpose buffers - // <----------------------------------------------------------------------------------> // - CL_CHECK(clReleaseMemObject(qT_d)); - CL_CHECK(clReleaseMemObject(dT_d)); - - // deallocate temporary images - CL_CHECK(clReleaseMemObject(q_d_image1D)); - CL_CHECK(clReleaseMemObject(d_d_image1D)); - CL_CHECK(clReleaseMemObject(qT_d_image1D)); - CL_CHECK(clReleaseMemObject(dT_d_image1D)); - // <----------------------------------------------------------------------------------> // - // end transpose - // <----------------------------------------------------------------------------------> // + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); } #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -5820,8 +5672,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_kernels(backend_ctx, tensor)) { - cl_int err; - cl_kernel kernel; + ggml_cl_buffer buf_trans_q; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_unpacked; cl_int M = tensor->ne[1]; // ne01 cl_int K = tensor->ne[0]; // ne00 @@ -5833,46 +5686,12 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); - cl_mem buf_trans_q; - cl_mem buf_trans_d; - - CL_CHECK((buf_trans_q = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_q, NULL, &err), err)); - CL_CHECK((buf_trans_d = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_d, NULL, &err), err)); - - kernel = backend_ctx->kernel_transpose_16_buf; - - // transpose q back - cl_int stride_k_q = K/4; - size_t local_size_q[3] = {64, 1, 1}; - size_t global_size_q[3] = {(size_t)M, (size_t)stride_k_q, 1}; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_q)); - - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_size_q, local_size_q, 0, NULL, NULL)); - - // transpose scales back - cl_int stride_k_d = K/32; - size_t local_size_d[3] = {64, 1, 1}; - size_t global_size_d[3] = {(size_t)M, (size_t)stride_k_d, 1}; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_d)); - - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_size_d, local_size_d, 0, NULL, NULL)); + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); - // unpack - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -5880,25 +5699,15 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; - kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, NULL)); - - // read back to host - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); - - CL_CHECK(clReleaseMemObject(data_device)); - CL_CHECK(clReleaseMemObject(buf_trans_q)); - CL_CHECK(clReleaseMemObject(buf_trans_d)); - + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); return; } #endif @@ -10073,6 +9882,235 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(clReleaseMemObject(D_sub_buffer)); } +static void ggml_cl_mul_mat_q4_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32; + if (M == 4096 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_4096; + } else if (M == 4096 && K == 11008) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_11008; + } else if (M == 11008 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_11008_1_4096; + } else if (M == 32000 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_32000_1_4096; + } + + int r2 = 1; + int r3 = 1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; // Specify the starting offset (in bytes) + region.size = M * N * sizeof(float); // Specify the size of the sub-buffer + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=4; + local_work_size_t[1]=8; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=2; + local_work_size_t[1]=8; + } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size_t[0]=1; + local_work_size_t[1]=8; + } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=2; + local_work_size_t[1]=8; + } + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_0_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 1; + local_work_size[1] = 128; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -10495,7 +10533,7 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t img_desc.buffer = b_sub_buf; CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; + kernel = backend_ctx->kernel_gemv_noshuffle_q8_0_f32; int r2 = 1; int r3 = 1; @@ -10585,7 +10623,7 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); // gemm - kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; + kernel = backend_ctx->kernel_gemm_noshuffle_q8_0_f32; int padded_N = N + padding; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); @@ -11195,8 +11233,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -11219,28 +11257,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; - - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; - - const int ne0 = dst ? dst->ne[0] : 0; - const int ne1 = dst ? dst->ne[1] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); int r2 = ne12/ne02; int r3 = ne13/ne03; @@ -11256,8 +11278,6 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_kernel kernel; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - cl_context context = backend_ctx->context; - if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 && // dst is wrapped with image1d_buffer, the size limit applies, also src0 @@ -11284,340 +11304,52 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { + // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require + // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that + // limit, so the check is omitted. - // init CL objects - // <--------------------------------------------> // - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - cl_buffer_region region; - cl_mem A_image1d = nullptr; - cl_mem B_image1d = nullptr; - cl_mem B_sub_buffer = nullptr; - cl_mem C_d = nullptr; - // for B transpose - cl_mem B_d = nullptr; - cl_mem B_d_input_image = nullptr; - // <--------------------------------------------> // - - // define matrix dimensions - // <--------------------------------------------> // - int M = ne01; - int N = ne1; - int K = ne00; - int padding; - // <--------------------------------------------> // - - // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require - // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that - // limit, so the check is omitted. + // q4_0 x fp32 + if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_0_f32_adreno(backend, src0, src1, dst); + return; + } - // q4_1 x fp32 - if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) { + // q4_1 x fp32 + if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst); return; - } - - // iq4_nl x fp32 - if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { - ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); - return; - } + } - // q8_0 x fp32 - if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && - enable_adreno_trans_weight(backend_ctx, src0)) { - ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst); + // iq4_nl x fp32 + if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); return; - } + } + + // q8_0 x fp32 + if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && + enable_adreno_trans_weight(backend_ctx, src0)) { + ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst); + return; + } - // q4_k x fp32 - if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + // q4_k x fp32 + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); return; - } - - // q6_K x fp32 - if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { - ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); - return; - } - - // q5_K x fp32 - if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { - ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); - return; - } - - // q4_0 x fp32 - if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { - // TODO: remove duplicate definitions of image description + format -- move to top - - // create an image for A - // <--------------------------------------------> // - if (N == 1) { - img_fmt_1d = { CL_R, CL_UNSIGNED_INT32}; - } else { - img_fmt_1d = { CL_R, CL_FLOAT}; - } - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 2 / 4; // Divide by 4 for char -> float - img_desc_1d.buffer = extra0_q4_0->q; - A_image1d = clCreateImage( - context, - CL_MEM_READ_ONLY, - &img_fmt_1d, - &img_desc_1d, - NULL, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - - - // create a sub_buffer for B - // <--------------------------------------------> // - region.origin = (extra1->offset); - region.size = K * N * sizeof(float); - B_sub_buffer = clCreateSubBuffer( - extra1->data_device, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - - // transpose activation for Skyler's gemm - if (N != 1) { - //how many extra elements beyond multiple of 8 - int extra_elements = N % 8; - - //how much padding to add - padding = 0; - if (extra_elements > 0){ - padding = 8 - extra_elements; - } - - // Specify the starting offset (in bytes) - region.origin = 0; - // Specify the size of the sub-buffer (divide by 2 for FP16) - region.size = K * (N + padding) * sizeof(float)/2; - backend_ctx->prealloc_act_trans.allocate(context, region.size); - - B_d = clCreateSubBuffer( - backend_ctx->prealloc_act_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); - - cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT }; - cl_image_desc image_desc_B_d_input = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(K * N / 4), - 0, 0, 0, 0, 0, 0, 0, { B_sub_buffer } - }; - B_d_input_image = clCreateImage( - context, - 0, - &image_format_B_d_input, - &image_desc_B_d_input, - NULL, - &status); - CL_CHECK(status); - - cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) - cl_image_desc image_desc_B_d_output = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(K * (N + padding)/4), - 0, 0, 0, 0, 0, 0, 0, { B_d } - }; - B_image1d = clCreateImage( - context, - 0, - &image_format_B_d_output, - &image_desc_B_d_output, - NULL, - &status); - CL_CHECK(status); - - int height_B = N/4; - if (height_B == 0) { - height_B = 1; - } - int width_B = K/4; - int padded_height_B = (N + padding)/4; - - kernel = backend_ctx->kernel_transpose_32_16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); - - size_t local_size_t[2] = { 1, 16 }; - //WGS tuning - if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=4; - local_size_t[1]=8; - } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=2; - local_size_t[1]=8; - } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { - local_size_t[0]=1; - local_size_t[1]=8; - } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=2; - local_size_t[1]=8; - } - - size_t global_size_t[2] = { - static_cast(width_B), - static_cast(padded_height_B) - }; - - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); - } else { - // no need to transpose B in other cases - // create an image for B from sub_buffer - // <--------------------------------------------> // - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_width = K * N / 4; - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = B_sub_buffer; - B_image1d = clCreateImage( - context, - CL_MEM_READ_ONLY, - &img_fmt_1d, - &img_desc_1d, - NULL, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - } - - // choose gemm or gemv kernel - // <--------------------------------------------> // - if (N == 1) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; - if (M == 4096 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; - } else if (M == 4096 && K == 11008) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; - } else if (M == 11008 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; - } else if (M == 32000 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; - } - } else { - kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4; - } - // <--------------------------------------------> // - - // set kernel args - // <--------------------------------------------> // - cl_uint k_arg = 0; - - if (N == 1) { - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q4_0->d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); - } else { - region.origin = extrad->offset; // Specify the starting offset (in bytes) - region.size = M * N * sizeof(float); // Specify the size of the sub-buffer - C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - int padded_N = ne1 + padding; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_dextra0_q4_0->q - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); //M - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &padded_N)); //N with padding - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); //K - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne1)); //N without padding - } - // <--------------------------------------------> // - - // choose workgroup size - // <--------------------------------------------> // - size_t global_work_size[3] = { - 64, static_cast((M+63)/64), static_cast((N+31)/32)}; - size_t local_work_size[3] = {64, 2, 4}; - - global_work_size[0] = (size_t)(ceil((float)ne1/8)); - global_work_size[1] = (size_t)(ne01/4); - global_work_size[2] = (size_t)(1); - - local_work_size[0] = (size_t)(1); //4x32 for FP32 - local_work_size[1] = (size_t)(128); - local_work_size[2] = (size_t)(1); - - //WGS tuning - if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 1; - local_work_size[1] = 128; - } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 2; - local_work_size[1] = 64; - } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { - local_work_size[0] = 2; - local_work_size[1] = 64; - } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 2; - local_work_size[1] = 64; } - if (N == 1) { - size_t wavesize = backend_ctx->adreno_wave_size; - local_work_size[0] = wavesize; // localsize - local_work_size[1] = 4; // reduce factor - local_work_size[2] = 1; - - global_work_size[0] = (((M / 2) + wavesize - 1) / wavesize) * wavesize; - global_work_size[1] = 4; // reduce factor - global_work_size[2] = 1; + // q6_K x fp32 + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); + return; } - // <--------------------------------------------> // - - // enqueue kernel with profiling - // <--------------------------------------------> // - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - // <--------------------------------------------> // - - // deallocate sub buffers and images - // <--------------------------------------------> // - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(B_image1d)); - if (N != 1) { - CL_CHECK(clReleaseMemObject(B_d)); - CL_CHECK(clReleaseMemObject(B_d_input_image)); - CL_CHECK(clReleaseMemObject(C_d)); + // q5_K x fp32 + if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); + return; } - // <--------------------------------------------> // - - return; - } } // if (ne01 && ne1) #endif // GGML_OPENCL_USE_ADRENO_KERNELS diff --git a/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl similarity index 99% rename from ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl rename to ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl index ecb577b9933..159378049fb 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl @@ -17,7 +17,7 @@ REQD_SUBGROUP_SIZE_128 #endif -kernel void kernel_mul_mat_Ab_Bi_8x4( +kernel void kernel_gemm_noshuffle_q4_0_f32( global const ushort * src0_q, // quantized A global const half * src0_d, // A scales __read_only image1d_buffer_t src1, // B (1d image) diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl rename to ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl index 51ce2121ce2..7f06a22a2cb 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl @@ -11,7 +11,7 @@ REQD_SUBGROUP_SIZE_128 #endif -kernel void kernel_mul_mm_q8_0_f32_8x4( +kernel void kernel_gemm_noshuffle_q8_0_f32( global const uint * src0_q, global const half * src0_d, __read_only image1d_buffer_t src1, diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl index 469d3edef00..10683206919 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl @@ -191,7 +191,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q4_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales __read_only image1d_buffer_t src1, // B @@ -238,21 +238,21 @@ __kernel void kernel_gemv_noshuffle( regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl index ee5c79f000d..571a375da7f 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl @@ -191,7 +191,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q4_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales __read_only image1d_buffer_t src1, // B @@ -232,21 +232,21 @@ __kernel void kernel_gemv_noshuffle( regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl From a6d678954ab9935383cfb74a2b64735537a4abb8 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Tue, 5 May 2026 11:43:03 -0500 Subject: [PATCH 019/289] Hexagon: Process M-tail rows on HMX instead of HVX (llama/22724) * hex-mm: process m-tail rows on HMX instead of HVX * hmx-mm: unroll and optimize padded activation loop --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 51 ++++++++++++++++++---- ggml/src/ggml-hexagon/htp/matmul-ops.c | 36 +++------------ 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 2666a78a96a..9e8c9966e04 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, // activations : fp32 -> fp16 static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { - for (int r = 0; r < n_rows; r += 2) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - const bool next_row_valid = (r + 1) < n_rows; - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); for (int c = 0; c < k_block; c += 32) { HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = r < n_rows; + const bool row1_valid = (r + 1) < n_rows; + + const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL; + const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL; + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); @@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). const size_t m_block_cost = (size_t) n * 3; const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + m_block_cost, n_block_cost, &M_BLOCK_SIZE, &N_BLOCK_SIZE, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; @@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds if (m >= 128) { size_t mc = 0, nc = 0, used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n, + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && hmx_ceil_div((size_t) n, nc) >= 2) { @@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } if (!use_pipeline) { - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n, + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); @@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, /*per_n=*/3 * vec_dot_size, /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*per_mn=*/sizeof(__fp16), + hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n, /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); @@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co /*per_n=*/3 * vec_dot_size, // W + S0 + S1 /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch /*per_mn=*/sizeof(__fp16), // O - m, n, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n, /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index a0c265132c8..2461ae617fa 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul_hvx(octx); } - // M alignment: when M > 32 but not 32-aligned, we split into - // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). - // When M <= 32 and not 32-aligned, fall back entirely to HVX. + // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows) + // is handled by HMX itself; when M < 32 fall back to HVX. const int m_total = (int) src1->ne[1]; - const int m_tail = m_total % 32; - const int m_hmx = m_total - m_tail; + const int m_hmx = m_total & ~31; // 0 when M < 32 if (m_hmx == 0) { return op_matmul_hvx(octx); @@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) { int k = (int) src0->ne[0]; // inner dimension int n = (int) src0->ne[1]; // weight columns - // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- int ret = -1; // Row strides in elements. For compact tensors these equal k; for @@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) { .dst = (float *) dst->data, .activation = (float *) src1->data, .permuted_weight = (const __fp16 *) src0->data, - .m = m_hmx, + .m = m_total, .k = k, .n = n, .act_stride = act_stride, @@ -3048,12 +3045,12 @@ int op_matmul(struct htp_ops_context * octx) { } else { ret = hmx_mat_mul_permuted_w16a32(octx->ctx, (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, - m_hmx, k, n, act_stride, wgt_stride); + m_total, k, n, act_stride, wgt_stride); } } else { ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_hmx, k, n, (int) src0->type); + m_total, k, n, (int) src0->type); } if (ret != 0) { @@ -3061,27 +3058,6 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul(octx); } - // --- Phase 2: HVX on the remaining m_tail rows --- - if (m_tail > 0) { - // copy of src1 and dst - struct htp_tensor src1_tail = *src1; - struct htp_tensor dst_tail = *dst; - - src1_tail.ne[1] = m_tail; // only tail rows - dst_tail.ne[1] = m_tail; // only tail rows - - // Offset activation and dst pointers past the HMX-processed rows. - // Use nb[1] (row stride in bytes) to compute the byte offset. - src1_tail.data += (uint32_t) m_hmx * src1->nb[1]; - dst_tail.data += (uint32_t) m_hmx * dst->nb[1]; - - octx->src[1] = &src1_tail; - octx->dst = &dst_tail; - - FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data); - return op_matmul_hvx(octx); - } - return 0; #endif // HTP_HAS_HMX } From 3613268bc73ffaaf7c5d768ecfbfee24125e67b0 Mon Sep 17 00:00:00 2001 From: fl0rianr Date: Wed, 6 May 2026 07:12:48 +0200 Subject: [PATCH 020/289] ggml : use `CL_DEVICE_GLOBAL_MEM_SIZE` as memory estimate for OpenCL --fit (llama/22688) * ggml : report estimated OpenCL memory for --fit Signed-off-by: Florian Reinle * ggml : estimated OpenCL memory backend integrated Signed-off-by: Florian Reinle --------- Signed-off-by: Florian Reinle --- ggml/src/ggml-opencl/ggml-opencl.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8c7bf98c16f..d344bde0fe3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -389,6 +389,7 @@ struct ggml_backend_opencl_context { ADRENO_GPU_GEN adreno_gen; cl_int alignment; + size_t global_mem_size; size_t max_alloc_size; size_t max_workgroup_size; bool fp16_support; @@ -3386,6 +3387,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->alignment = base_align_in_bits / 8u; GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); + clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &backend_ctx->global_mem_size, NULL); + GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); + clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); @@ -6356,11 +6360,16 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ } static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // no memory to report - *free = 0; - *total = 0; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) dev_ctx->backend_ctx; - GGML_UNUSED(dev); + static const size_t opencl_extra_margin = 1024ull*1024ull*1024ull; + + // OpenCL does not provide reliable currently-free device memory. + // Use total/global memory as a best-effort upper bound. + // Improved safety: Reduce by a 1GiB extra margin for common --fit + *total = backend_ctx->global_mem_size; + *free = *total > opencl_extra_margin ? *total - opencl_extra_margin : 0; } static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) { From d3f16afcf57d7f6b9ae7aa469dbebd7113dc1dc1 Mon Sep 17 00:00:00 2001 From: zzzzwc Date: Wed, 6 May 2026 15:41:14 +0800 Subject: [PATCH 021/289] ggml-cpu: fuse RMS_NORM + MUL on CPU backend (llama/22423) --- ggml/src/ggml-cpu/ggml-cpu.c | 53 +++++++++++++++++++++++- ggml/src/ggml-cpu/ops.cpp | 78 ++++++++++++++++++++++++++++-------- ggml/src/ggml-cpu/ops.h | 1 + 3 files changed, 115 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2d6cc1fcd46..8b7acafdaa8 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2965,6 +2965,45 @@ struct ggml_cplan ggml_graph_plan( return cplan; } + +// Try to fuse the current node with subsequent nodes for better performance. +// Returns the number of nodes skipped by fusion (>=1), or 0 if no fusion was applied. +static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards + +static int ggml_cpu_try_fuse_ops( + const struct ggml_cgraph * cgraph, + const int node_n, + const struct ggml_compute_params * params, + const struct ggml_cplan * cplan) { + + if (ggml_cpu_disable_fusion || cplan->use_ref) { + return 0; + } + + struct ggml_tensor * node = cgraph->nodes[node_n]; + + if (node->op == GGML_OP_RMS_NORM) { + // RMS_NORM + MUL fusion + const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL }; + if (ggml_can_fuse(cgraph, node_n, fuse_ops, 2)) { + struct ggml_tensor * mul_node = cgraph->nodes[node_n + 1]; + const struct ggml_tensor * mul_w = (mul_node->src[0] == node) + ? mul_node->src[1] : mul_node->src[0]; + if (node->src[0]->type == GGML_TYPE_F32 && + mul_node->type == GGML_TYPE_F32 && + mul_w->type == GGML_TYPE_F32 && + mul_w->ne[0] == node->ne[0] && + mul_w->nb[0] == sizeof(float)) { + + ggml_compute_forward_rms_norm_mul_fused(params, node, mul_node); + return 1; + } + } + } + + return 0; +} + static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; struct ggml_threadpool * tp = state->threadpool; @@ -3001,7 +3040,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { continue; } - ggml_compute_forward(¶ms, node); + // TODO: move fused-op detection into ggml_graph_plan so fusion decisions are made once at planning time + // Try fused ops, fall back to normal compute + const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, ¶ms, cplan); + if (n_fused > 0) { + node_n += n_fused; + } else { + ggml_compute_forward(¶ms, node); + } if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { @@ -3763,6 +3809,11 @@ void ggml_cpu_init(void) { ggml_init_riscv_arch_features(); #endif + { + const char * env = getenv("GGML_CPU_DISABLE_FUSION"); + ggml_cpu_disable_fusion = (env != NULL && atoi(env) == 1); + } + is_first_call = false; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 211f1ba1b2f..6bc8dc150ce 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3713,11 +3713,27 @@ void ggml_compute_forward_norm( // ggml_compute_forward_group_rms_norm +// fusion kinds that can be combined with the rms_norm computation in a single pass. +// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...). +enum ggml_rms_norm_fuse_op { + GGML_RMS_NORM_FUSE_OP_NONE, + GGML_RMS_NORM_FUSE_OP_MUL, +}; + +template static void ggml_compute_forward_rms_norm_f32( const ggml_compute_params * params, - ggml_tensor * dst) { + ggml_tensor * dst_rms_norm, + ggml_tensor * dst_fused = nullptr) { - const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0 = dst_rms_norm->src[0]; + const ggml_tensor * src1 = nullptr; + ggml_tensor * dst = dst_rms_norm; + + if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) { + src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0]; + dst = dst_fused; + } GGML_ASSERT(ggml_are_same_shape(src0, dst)); @@ -3726,11 +3742,10 @@ static void ggml_compute_forward_rms_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + GGML_TENSOR_BINARY_OP_LOCALS float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - + memcpy(&eps, dst_rms_norm->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); // TODO: optimize @@ -3740,25 +3755,32 @@ static void ggml_compute_forward_rms_norm_f32( const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); ggml_float sum = 0.0; + // worth switching to explicit SIMD? for (int64_t i00 = 0; i00 < ne00; i00++) { sum += (ggml_float)(x[i00] * x[i00]); } - const float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - + const float mean = sum/ne00; const float scale = 1.0f/sqrtf(mean + eps); // if you hit this, likely you got an inf somewhere earlier assert(scale > 0.0f); - ggml_vec_scale_f32(ne00, y, scale); + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) { + const int64_t i11 = i01 % ne11; + const int64_t i12 = i02 % ne12; + const int64_t i13 = i03 % ne13; + const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + y[i00] = x[i00] * scale * w[i00]; + } + } else { + memcpy(y, x, ne00 * sizeof(float)); + ggml_vec_scale_f32(ne00, y, scale); + } } } } @@ -3773,7 +3795,31 @@ void ggml_compute_forward_rms_norm( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_rms_norm_f32(params, dst); + ggml_compute_forward_rms_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass. +// This avoids materializing the intermediate rms_norm result in memory. +void ggml_compute_forward_rms_norm_mul_fused( + const ggml_compute_params * params, + ggml_tensor * dst_rms_norm, + ggml_tensor * dst_mul) { + + GGML_ASSERT(dst_mul != nullptr); + GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm); + + const ggml_tensor * src0 = dst_rms_norm->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, dst_rms_norm, dst_mul); } break; default: { diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 29efdeee37f..7398e561894 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -44,6 +44,7 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, stru void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_rms_norm_mul_fused(const struct ggml_compute_params * params, struct ggml_tensor * dst_rms_norm, struct ggml_tensor * dst_mul); void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); From 4395364605738226787117697e7227ea514fcd3e Mon Sep 17 00:00:00 2001 From: pl752 Date: Thu, 7 May 2026 18:09:25 +0500 Subject: [PATCH 022/289] ggml-cpu: Optimized risc-v cpu q1_0 dot --- ggml/src/ggml-cpu/arch-fallback.h | 1 - ggml/src/ggml-cpu/arch/riscv/quants.c | 98 +++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 595ded09f03..b0391a67c88 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -203,7 +203,6 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index d3278d6489f..ee69e5ab5e5 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -480,6 +480,104 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +#if defined(__riscv_v) +static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl256(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + //LMUL = 1, VLMAX = 32 + const size_t vl32 = __riscv_vsetvl_e8m1(32); + assert(vl32 == 32); + + const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1); + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + + float acc = 0; + + for (int k = 0; k < 4; ++k) { + const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k]; + const vbool8_t is_not_zero = __riscv_vlm_v_b8(x[ib].qs + 4 * k, vl32); + + const vint8m1_t qy = __riscv_vle8_v_i8m1(yb->qs, vl32); + const vint8m1_t neg_qy = __riscv_vneg_v_i8m1(qy, vl32); + const vint8m1_t sy = __riscv_vmerge_vvm_i8m1(neg_qy, qy, is_not_zero, vl32); + + const vint16m1_t red = __riscv_vwredsum_vs_i8m1_i16m1(sy, zero, vl32); + acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float)__riscv_vmv_x_s_i16m1_i16(red); + } + + sumf += d0 * acc; + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl128(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + //LMUL = 2, VLMAX = 32 + const size_t vl32 = __riscv_vsetvl_e8m2(32); + assert(vl32 == 32); + + const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1); + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + + float acc = 0; + + for (int k = 0; k < 4; ++k) { + const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k]; + const vbool4_t is_not_zero = __riscv_vlm_v_b4(x[ib].qs + 4 * k, vl32); + + const vint8m2_t qy = __riscv_vle8_v_i8m2(yb->qs, vl32); + const vint8m2_t neg_qy =__riscv_vneg_v_i8m2(qy, vl32); + const vint8m2_t sy = __riscv_vmerge_vvm_i8m2(neg_qy, qy, is_not_zero, vl32); + + const vint16m1_t red = __riscv_vwredsum_vs_i8m2_i16m1(sy, zero, vl32); + acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float)__riscv_vmv_x_s_i16m1_i16(red); + } + + sumf += d0 * acc; + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined(__riscv_v) + assert(nrc == 1); + + const size_t vlen_bits = __riscv_vlenb() * 8; + + if (vlen_bits >= 256) { + ggml_vec_dot_q1_0_q8_0_vl256(n, s, vx, vy); + } else if (vlen_bits >= 128) { + ggml_vec_dot_q1_0_q8_0_vl128(n, s, vx, vy); + } else { + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + } +#else + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); From bd693bb1ebf5cb8bc987b1a43dd8bf31a7b65edc Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Thu, 7 May 2026 08:51:33 -0700 Subject: [PATCH 023/289] sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET (llama/22149) * sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET Signed-off-by: Chun Tao * Fix abort during test-backend-ops Signed-off-by: Todd Malsbary * Regenerate ops.md Signed-off-by: Todd Malsbary * Add scope_dbg_print to newly added SYCL ops. Also add scope_dbg_print to existing ssm_conv op. Signed-off-by: Todd Malsbary --------- Signed-off-by: Chun Tao Signed-off-by: Todd Malsbary Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/cumsum.cpp | 148 +++++++++++++++++++++ ggml/src/ggml-sycl/cumsum.hpp | 5 + ggml/src/ggml-sycl/diag.cpp | 67 ++++++++++ ggml/src/ggml-sycl/diag.hpp | 5 + ggml/src/ggml-sycl/fill.cpp | 55 ++++++++ ggml/src/ggml-sycl/fill.hpp | 5 + ggml/src/ggml-sycl/gated_delta_net.hpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 37 +++++- ggml/src/ggml-sycl/solve_tri.cpp | 172 +++++++++++++++++++++++++ ggml/src/ggml-sycl/solve_tri.hpp | 8 ++ ggml/src/ggml-sycl/ssm_conv.cpp | 7 +- ggml/src/ggml-sycl/ssm_scan.cpp | 156 ++++++++++++++++++++++ ggml/src/ggml-sycl/ssm_scan.hpp | 5 + 13 files changed, 669 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-sycl/cumsum.cpp create mode 100644 ggml/src/ggml-sycl/cumsum.hpp create mode 100644 ggml/src/ggml-sycl/diag.cpp create mode 100644 ggml/src/ggml-sycl/diag.hpp create mode 100644 ggml/src/ggml-sycl/fill.cpp create mode 100644 ggml/src/ggml-sycl/fill.hpp create mode 100644 ggml/src/ggml-sycl/solve_tri.cpp create mode 100644 ggml/src/ggml-sycl/solve_tri.hpp create mode 100644 ggml/src/ggml-sycl/ssm_scan.cpp create mode 100644 ggml/src/ggml-sycl/ssm_scan.hpp diff --git a/ggml/src/ggml-sycl/cumsum.cpp b/ggml/src/ggml-sycl/cumsum.cpp new file mode 100644 index 00000000000..c1c5fe4fe4a --- /dev/null +++ b/ggml/src/ggml-sycl/cumsum.cpp @@ -0,0 +1,148 @@ +#include "cumsum.hpp" +#include "common.hpp" + +#include + +#define SYCL_CUMSUM_BLOCK_SIZE 256 + +static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) { + return sycl::inclusive_scan_over_group(item.get_sub_group(), x, sycl::plus()); +} + +static void cumsum_f32_kernel( + const float * __restrict__ src, float * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t d1, const int64_t d2, const int64_t d3, + const sycl::nd_item<3> & item, float * smem) { + + const int tid = item.get_local_id(2); + const int block_size = item.get_local_range(2); + const int lane = tid % WARP_SIZE; + const int warp = tid / WARP_SIZE; + const int warps_per_block = block_size / WARP_SIZE; + + float * s_vals = smem; + float * s_warp_sums = smem + block_size; + float * s_carry = smem + block_size + warps_per_block; + + if (tid == 0) { + s_carry[0] = 0.0f; + } + item.barrier(sycl::access::fence_space::local_space); + + const int64_t i3 = item.get_group(0); + const int64_t i2 = item.get_group(1); + const int64_t i1 = item.get_group(2); + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3; + + constexpr int num_unroll = 4; + float temp[num_unroll]; + + for (int64_t i = 0; i < ne00; i += num_unroll * block_size) { + int64_t idx = i + tid * num_unroll; + + temp[0] = (idx < ne00 ? src_row[idx] : 0.0f); +#pragma unroll + for (int j = 1; j < num_unroll; j++) { + temp[j] = temp[j - 1]; + if (idx + j < ne00) { + temp[j] += src_row[idx + j]; + } + } + + float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f; + + val = warp_prefix_inclusive_sum_f32(val, item); + s_vals[tid] = val; + + if (lane == WARP_SIZE - 1) { + s_warp_sums[warp] = val; + } + item.barrier(sycl::access::fence_space::local_space); + + if (warp == 0) { + float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f; + float inc = warp_prefix_inclusive_sum_f32(w, item); + if (tid < warps_per_block) { + s_warp_sums[tid] = inc - w; + } + if (tid == warps_per_block - 1) { + s_carry[1] = inc; + } + } + item.barrier(sycl::access::fence_space::local_space); + + float carry = s_carry[0]; + float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1]; + +#pragma unroll + for (int j = 0; j < num_unroll; j++) { + if (idx + j < ne00) { + dst_row[idx + j] = temp[j] + final_offset; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + if (tid == 0) { + s_carry[0] += s_carry[1]; + } + } +} + +inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src_d = static_cast(src0->data); + float * dst_d = static_cast(dst->data); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const size_t ts = sizeof(float); + const int64_t s01 = src0->nb[1] / ts; + const int64_t s02 = src0->nb[2] / ts; + const int64_t s03 = src0->nb[3] / ts; + const int64_t d1 = dst->nb[1] / ts; + const int64_t d2 = dst->nb[2] / ts; + const int64_t d3 = dst->nb[3] / ts; + + const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; + int block_size = num_warps * WARP_SIZE; + block_size = std::min(block_size, SYCL_CUMSUM_BLOCK_SIZE); + const int warps_per_block = block_size / WARP_SIZE; + const int smem_size = block_size + warps_per_block + 2; + + const sycl::range<3> grid(ne03, ne02, ne01); + const sycl::range<3> block(1, 1, block_size); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor smem_acc(sycl::range<1>(smem_size), cgh); + cgh.parallel_for( + sycl::nd_range<3>(grid * block, block), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03, + s01, s02, s03, d1, d2, d3, + item, get_pointer(smem_acc)); + }); + }); +} + +void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_cumsum(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/cumsum.hpp b/ggml/src/ggml-sycl/cumsum.hpp new file mode 100644 index 00000000000..f1a564472c5 --- /dev/null +++ b/ggml/src/ggml-sycl/cumsum.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/diag.cpp b/ggml/src/ggml-sycl/diag.cpp new file mode 100644 index 00000000000..c4264fee342 --- /dev/null +++ b/ggml/src/ggml-sycl/diag.cpp @@ -0,0 +1,67 @@ +#include "diag.hpp" +#include "common.hpp" + +#define SYCL_DIAG_BLOCK_SIZE 256 + +template +static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src, + const int64_t ne0, const int64_t ne1, + const int64_t ne2, const int64_t ne3, + const int64_t total_elements, + const sycl::nd_item<1> & item) { + const int64_t i = item.get_global_id(0); + if (i >= total_elements) { + return; + } + + const int64_t i0 = i % ne0; + const int64_t i1 = (i / ne0) % ne1; + const int64_t i2 = (i / (ne0 * ne1)) % ne2; + const int64_t i3 = i / (ne0 * ne1 * ne2); + + const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; + + if (i0 == i1) { + const int64_t batch_idx = i3 * ne2 + i2; + dst[dst_idx] = src[batch_idx * ne0 + i0]; + } else { + dst[dst_idx] = T(0); + } + + (void)ne3; +} + +inline void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[1] == 1); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const void * src0_d = src0->data; + void * dst_d = dst->data; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + const int64_t n_elems = ggml_nelements(dst); + const int64_t num_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_DIAG_BLOCK_SIZE, SYCL_DIAG_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + diag_kernel(static_cast(dst_d), + static_cast(src0_d), + ne0, ne1, ne2, ne3, n_elems, item); + }); +} + +void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_diag(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/diag.hpp b/ggml/src/ggml-sycl/diag.hpp new file mode 100644 index 00000000000..20d7ce4895d --- /dev/null +++ b/ggml/src/ggml-sycl/diag.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/fill.cpp b/ggml/src/ggml-sycl/fill.cpp new file mode 100644 index 00000000000..28e618e4ef5 --- /dev/null +++ b/ggml/src/ggml-sycl/fill.cpp @@ -0,0 +1,55 @@ +#include "fill.hpp" +#include "common.hpp" + +#define SYCL_FILL_BLOCK_SIZE 256 + +template +static void fill_kernel(T * dst, const int64_t k, const T value, + const sycl::nd_item<1> & item) { + const int64_t i = (int64_t)item.get_global_id(0); + if (i >= k) { + return; + } + dst[i] = value; +} + +inline void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst)); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE; + void * dst_d = dst->data; + + switch (dst->type) { + case GGML_TYPE_F32: + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + fill_kernel(static_cast(dst_d), k, value, item); + }); + break; + case GGML_TYPE_F16: + { + sycl::half h_value = sycl::half(value); + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + fill_kernel(static_cast(dst_d), k, h_value, item); + }); + } + break; + default: + GGML_ABORT("unsupported type"); + } +} + +void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0); + ggml_sycl_op_fill(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/fill.hpp b/ggml/src/ggml-sycl/fill.hpp new file mode 100644 index 00000000000..b2adb94ff52 --- /dev/null +++ b/ggml/src/ggml-sycl/fill.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/gated_delta_net.hpp b/ggml/src/ggml-sycl/gated_delta_net.hpp index a3308ee8763..350b4ce2f66 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.hpp +++ b/ggml/src/ggml-sycl/gated_delta_net.hpp @@ -5,4 +5,5 @@ #include "common.hpp" #include "ggml.h" +void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f06147eeeb8..29ecedb5de9 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -54,7 +54,12 @@ #include "ggml-sycl/set.hpp" #include "ggml-sycl/ssm_conv.hpp" #include "ggml-sycl/sycl_hw.hpp" - +#include "ggml-sycl/ssm_scan.hpp" +#include "ggml-sycl/fill.hpp" +#include "ggml-sycl/cumsum.hpp" +#include "ggml-sycl/diag.hpp" +#include "ggml-sycl/solve_tri.hpp" +#include "ggml-sycl/gated_delta_net.hpp" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; @@ -4394,6 +4399,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SSM_CONV: ggml_sycl_ssm_conv(ctx, dst); break; + case GGML_OP_SSM_SCAN: + ggml_sycl_ssm_scan(ctx, dst); + break; + case GGML_OP_FILL: + ggml_sycl_fill(ctx, dst); + break; + case GGML_OP_CUMSUM: + ggml_sycl_cumsum(ctx, dst); + break; + case GGML_OP_DIAG: + ggml_sycl_diag(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_sycl_solve_tri(ctx, dst); + break; case GGML_OP_ROLL: ggml_sycl_roll(ctx, dst); break; @@ -5104,6 +5124,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: return op->type == GGML_TYPE_F32; + case GGML_OP_SSM_SCAN: + if (op->src[3]->ne[0] == 1) { + // Mamba2 + // (kernel only supports (d_state == 128 || d_state == 256) && d_head % WARP_SIZE == 0) + return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % WARP_SIZE == 0; + } else { + // TODO Mamba-1 not yet ported to SYCL + return false; + } + case GGML_OP_FILL: + case GGML_OP_CUMSUM: + case GGML_OP_DIAG: + return true; + case GGML_OP_SOLVE_TRI: + return op->src[0]->ne[0] <= SYCL_SOLVE_TRI_MAX_N && op->src[1]->ne[0] <= SYCL_SOLVE_TRI_MAX_K; case GGML_OP_FLASH_ATTN_EXT: return ggml_sycl_flash_attn_ext_supported(device, op); default: diff --git a/ggml/src/ggml-sycl/solve_tri.cpp b/ggml/src/ggml-sycl/solve_tri.cpp new file mode 100644 index 00000000000..39326deee44 --- /dev/null +++ b/ggml/src/ggml-sycl/solve_tri.cpp @@ -0,0 +1,172 @@ +#include "solve_tri.hpp" +#include "common.hpp" +#include + +template +static void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const int64_t ne02, [[maybe_unused]] const int64_t ne03, + const int64_t nb02, const int64_t nb03, + const int64_t nb12, const int64_t nb13, + const int64_t nb2, const int64_t nb3, + const int n_arg, const int k_arg, + const sycl::nd_item<2> & item, float * sA) { + + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? k_arg : k_template; + + const int batch_idx = item.get_group(1); + const int lane = item.get_local_id(1) % WARP_SIZE; + const int col_idx = item.get_local_id(0); + + if (col_idx >= k) { + return; + } + + const int64_t i03 = batch_idx / ne02; + const int64_t i02 = batch_idx % ne02; + + const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03); + const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3); + + const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1); + +#pragma unroll + for (int i = 0; i < n * n; i += k * WARP_SIZE) { + const int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; + + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? n : half; + +#pragma unroll + for (int row = 0; row < nrows_low; ++row) { + float sum = 0.0f; + if (lane < row) { + sum += sA[row * n + lane] * x_low; + } + sum = warp_reduce_sum(sum); + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } + sum = warp_reduce_sum(sum); + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? x_low : x_high; + X_batch[row * k + col_idx] = val; + } + } +} + +static void solve_tri_f32_mkl(dpct::queue_ptr stream, + const float * A, float * X, + int n, int k, + int64_t ne02, [[maybe_unused]] int64_t ne03, + int64_t nb02, [[maybe_unused]] int64_t nb03, + int64_t nb2, [[maybe_unused]] int64_t nb3) { + const float alpha = 1.0f; + const int64_t total_batches = ne02 * ne03; + if (total_batches == 0) { + return; + } + + const int64_t stride_a = nb02 / sizeof(float); + const int64_t stride_x = nb2 / sizeof(float); + + oneapi::mkl::blas::trsm_batch( + *stream, + oneapi::mkl::side::right, + oneapi::mkl::uplo::upper, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::diag::nonunit, + k, n, alpha, + A, n, stride_a, + X, k, stride_x, + total_batches); +} + +inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const int n = src0->ne[0]; + const int k = src1->ne[0]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K); + + const float * A_d = static_cast(src0->data); + const float * B_d = static_cast(src1->data); + float * X_d = static_cast(dst->data); + + if (X_d != B_d) { + const int64_t total_elements = (int64_t)n * k * ne02 * ne03; + stream->memcpy(X_d, B_d, total_elements * sizeof(float)); + } + + const int64_t nb02 = src0->nb[2]; + const int64_t nb03 = src0->nb[3]; + const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; + const int64_t nb2 = dst->nb[2]; + const int64_t nb3 = dst->nb[3]; + + const int64_t total_batches = ne02 * ne03; + + if (n <= 2 * WARP_SIZE && k <= 32) { + const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE; + const sycl::range<2> grid(1, total_batches); + const sycl::range<2> block(k, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor smem_acc(sycl::range<1>(smem_size), cgh); + cgh.parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03, + nb02, nb03, nb12, nb13, nb2, nb3, + n, k, item, get_pointer(smem_acc)); + }); + }); + } else { + solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3); + } +} + +void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_solve_tri(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/solve_tri.hpp b/ggml/src/ggml-sycl/solve_tri.hpp new file mode 100644 index 00000000000..c7c34cfa2bb --- /dev/null +++ b/ggml/src/ggml-sycl/solve_tri.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "common.hpp" + +#define SYCL_SOLVE_TRI_MAX_N 64 +#define SYCL_SOLVE_TRI_MAX_K 64 + +void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ssm_conv.cpp b/ggml/src/ggml-sycl/ssm_conv.cpp index eea9a73d67e..e55223586a1 100644 --- a/ggml/src/ggml-sycl/ssm_conv.cpp +++ b/ggml/src/ggml-sycl/ssm_conv.cpp @@ -63,7 +63,7 @@ static void kernel_ssm_conv( }); } -void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +inline void ggml_sycl_op_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; @@ -125,3 +125,8 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { throw; } } + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_ssm_conv(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/ssm_scan.cpp b/ggml/src/ggml-sycl/ssm_scan.cpp new file mode 100644 index 00000000000..ae652981384 --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_scan.cpp @@ -0,0 +1,156 @@ +#include "ssm_scan.hpp" +#include "common.hpp" + +template +static void ssm_scan_f32_group( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok, + const sycl::nd_item<2> & item) { + + const int lane = item.get_local_id(1) % WARP_SIZE; + const int warp = item.get_local_id(1) / WARP_SIZE; + const int warp_idx = item.get_group(1) * c_factor + warp; + const int seq_idx = item.get_group(0); + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = n_head * d_head; + + float state[c_factor]; + float state_sum = 0.0f; + +#pragma unroll + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; + } + + for (int64_t i = 0; i < n_tok; i++) { + const float dt_val = dt_warp[i * stride_dt]; + const float dt_soft_plus = (dt_val <= 20.0f ? sycl::log1p(sycl::exp(dt_val)) : dt_val); + + state_sum = 0.0f; + const float dA = sycl::exp(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; +#pragma unroll + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; + } + + state_sum = warp_reduce_sum(state_sum); + + if (lane == 0) { + y_warp[i * stride_y] = state_sum; + } + } + +#pragma unroll + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; + } +} + +static void ssm_scan_f32_sycl( + const float * src0, const float * src1, const float * src2, const float * src3, + const float * src4, const float * src5, const int32_t * src6, float * dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, + const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, + const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, + const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, + dpct::queue_ptr stream) { + + // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! + GGML_ASSERT(src3_nb1 == sizeof(float)); + if (d_state == 128) { + constexpr int threads = 128; + constexpr int num_warps = threads / WARP_SIZE; + const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); + const sycl::range<2> block(1, threads); + stream->parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + ssm_scan_f32_group<128 / WARP_SIZE, 128>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); + }); + } else if (d_state == 256) { + constexpr int threads = 256; + constexpr int num_warps = threads / WARP_SIZE; + const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); + const sycl::range<2> block(1, threads); + stream->parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + ssm_scan_f32_group<256 / WARP_SIZE, 256>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); + }); + } else { + GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)"); + } +} + +inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const ggml_tensor * src3 = dst->src[3]; + const ggml_tensor * src4 = dst->src[4]; + const ggml_tensor * src5 = dst->src[5]; + const ggml_tensor * src6 = dst->src[6]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t nc = src0->ne[0]; + const int64_t nr = src0->ne[1]; + const int64_t nh = src1->ne[1]; + const int64_t ng = src4->ne[1]; + const int64_t n_t = src1->ne[2]; + const int64_t n_s = src1->ne[3]; + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst)); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + ssm_scan_f32_sycl( + static_cast(src0->data), static_cast(src1->data), + static_cast(src2->data), static_cast(src3->data), + static_cast(src4->data), static_cast(src5->data), + static_cast(src6->data), static_cast(dst->data), + src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], + src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], + s_off, nc, nr, nh, ng, n_t, n_s, stream); +} + +void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); + ggml_sycl_op_ssm_scan(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/ssm_scan.hpp b/ggml/src/ggml-sycl/ssm_scan.hpp new file mode 100644 index 00000000000..1f9731fb6fd --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_scan.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst); From 7774fe2c8d19813a3ad2d74ce33ba51ae01996da Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Thu, 7 May 2026 11:00:20 -0700 Subject: [PATCH 024/289] opencl: add opfilter regex for debugging (llama/22782) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d344bde0fe3..e5a5d42f6fb 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #undef MIN #undef MAX @@ -396,6 +397,8 @@ struct ggml_backend_opencl_context { bool has_vector_subgroup_broadcast; bool disable_fusion; + std::regex *opfilter = nullptr; // regex of ops to not claim + bool adreno_has_large_buffer; bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; @@ -3494,6 +3497,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; + const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); + if (str_opfilter) { + backend_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); + GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", str_opfilter); + } + dev_ctx->backend_ctx = backend_ctx.release(); return dev_ctx->backend_ctx; } @@ -4143,6 +4152,11 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + // reject ops that match the opfilter regex + if (backend_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *backend_ctx->opfilter)) { + return false; + } + switch (op->op) { case GGML_OP_NONE: return true; From 5fd75cda3fec6b87f724a5784479f8ff9348a7d4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 May 2026 21:43:40 +0300 Subject: [PATCH 025/289] llama : fix device state save/load (llama/22805) --- ggml/src/ggml-metal/ggml-metal.cpp | 44 +++++++++++++++--------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 35774254983..a1003b3acff 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -87,17 +87,17 @@ static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, } static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, - /* .set_tensor_2d = */ NULL, - /* .get_tensor_2d = */ NULL, - /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_shared_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_shared_clear, + /* .reset = */ NULL, }; // private buffer @@ -163,17 +163,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer } static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_private_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .set_tensor_2d = */ NULL, - /* .get_tensor_2d = */ NULL, - /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_private_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, }; static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { From 6e91ed3b338f221c51c02439abf4c4e952ebe306 Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Fri, 8 May 2026 03:59:29 +0800 Subject: [PATCH 026/289] CUDA: batch out_prod inner loop with cublasSgemmStridedBatched (llama/22651) * CUDA: batch out_prod inner loop with cublasSgemmStridedBatched * CUDA: batch out_prod inner loop with cublasSgemmStridedBatched * CUDA: add cublasSgemmStridedBatched mapping for HIP and MUSA backends --- ggml/src/ggml-cuda/out-prod.cu | 30 +++++++++++++++++++++++------- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index c9b2b699c6a..499903d09b1 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t dps2 = ne2 / ne02; const int64_t dps3 = ne3 / ne03; - // TODO batched matrix multiplication - for (int64_t i3 = 0; i3 < ne3; ++i3) { - for (int64_t i2 = 0; i2 < ne2; ++i2) { + if (dps2 == 1 && ne2 > 1) { + // src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM + GGML_ASSERT(ne2 <= std::numeric_limits::max()); + const int batch_count = (int) ne2; + for (int64_t i3 = 0; i3 < ne3; ++i3) { CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op, ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, - src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + &alpha, src0_d + (i3/dps3)*s03, lda, s02, + src1_d + i3 *s13, ldb, s12, + &beta, dst_d + i3 *s3, ldc, s2, + batch_count)); + } + } else { + // Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2 + // with non-uniform stride; would need cublasSgemmBatched with pointer arrays). + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + } } } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index e5d363c65d1..5e0e22c7fc2 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -48,6 +48,7 @@ #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm +#define cublasSgemmStridedBatched hipblasSgemmStridedBatched #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t #define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 940c34a9fb2..99e8fa3703e 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -32,6 +32,7 @@ #define cublasSetMathMode mublasSetMathMode #define cublasSetStream mublasSetStream #define cublasSgemm mublasSgemm +#define cublasSgemmStridedBatched mublasSgemmStridedBatched #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t #define cublasGetStatusString mublasGetStatusString From ef77e10404ade9b4e74f5a546d974c9463defe94 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Thu, 7 May 2026 21:17:07 -0700 Subject: [PATCH 027/289] opencl: add q4_0 MoE GEMM for Adreno (llama/22731) * Q4_0 MoE CLC pass sanity check * release program * opencl: fix whitespace * opencl: remove unused cl_program * opencl: break #if block to make it more clear * opencl: adjust format --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 296 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 86 +++++ .../kernels/gemm_moe_q4_0_f32_ns.cl | 252 +++++++++++++++ .../kernels/gemv_moe_q4_0_f32_ns.cl | 116 +++++++ 5 files changed, 743 insertions(+), 9 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 0a45a4daa13..ffde6a4f063 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -102,6 +102,8 @@ set(GGML_OPENCL_KERNELS mul_mv_id_q8_0_f32_flat mul_mv_id_mxfp4_f32 mul_mv_id_mxfp4_f32_flat + gemm_moe_q4_0_f32_ns + gemv_moe_q4_0_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e5a5d42f6fb..4e6f6fb43d2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -542,6 +542,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; @@ -600,6 +601,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; + cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -950,6 +952,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); @@ -2884,6 +2888,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_q4_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_0_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_0_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3657,11 +3695,14 @@ struct ggml_tensor_extra_cl_q4_0 { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } // Currently, q_img and d_img are only initialized when SMALL_ALLOC is // enabled. They point to the images in ggml_backend_opencl_buffer_context. // So, there is no need to release them here. // TODO: initialize them for non SMALL_PATH path, or remove them. - q_img = nullptr; d_img = nullptr; size_q = 0; size_d = 0; @@ -4926,17 +4967,53 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_0 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; // The optimized kernels need weights in natural order, so unshuffle. if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -4952,7 +5029,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; // transpose the weights and scales - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image if (use_adreno_kernels(backend_ctx, tensor)) { @@ -4966,10 +5043,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // Transpose d as ushort transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); } - #endif // GGML_OPENCL_USE_ADRENO_KERNELS - +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; - } if (tensor->type == GGML_TYPE_Q4_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -5689,6 +5764,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0_trans4_ns; + + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { ggml_cl_buffer buf_trans_q; ggml_cl_buffer buf_trans_d; @@ -12811,6 +12916,179 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, // subgroup mat vec switch (src0->type) { case GGML_TYPE_Q4_0: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_0_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } // fallback to generic Q4_0 MoE kernel + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat; if (backend_ctx->gpu_family == INTEL) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c1ad46f4435..c87450dc49e 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -190,6 +190,92 @@ kernel void kernel_restore_block_q4_0_noshuffle( } } +kernel void kernel_convert_block_q4_0_trans4_ns( + global struct block_q4_0 * src0, + __global uint * dst_q, + __global half * dst_d, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_0; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q4_0 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_0 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q4_0_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global struct block_q4_0 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_0; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_0 * b = dst0 + dst_blk_offset; + b->d = src_d[src_d_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // kernel_convert_block_q4_1 // Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl new file mode 100644 index 00000000000..02290c17eb1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl @@ -0,0 +1,252 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_0(q4, a_f16, scale) \ + a_f16.s0 = (half)((q4.s0 & 0x000F) - 8) * scale; \ + a_f16.s1 = (half)(((q4.s0 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.s2 = (half)(((q4.s0 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.s3 = (half)(((q4.s0 & 0xF000) >> 12) - 8) * scale; \ + a_f16.s4 = (half)((q4.s1 & 0x000F) - 8) * scale; \ + a_f16.s5 = (half)(((q4.s1 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.s6 = (half)(((q4.s1 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.s7 = (half)(((q4.s1 & 0xF000) >> 12) - 8) * scale; \ + a_f16.s8 = (half)((q4.s2 & 0x000F) - 8) * scale; \ + a_f16.s9 = (half)(((q4.s2 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.sa = (half)(((q4.s2 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.sb = (half)(((q4.s2 & 0xF000) >> 12) - 8) * scale; \ + a_f16.sc = (half)((q4.s3 & 0x000F) - 8) * scale; \ + a_f16.sd = (half)(((q4.s3 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.se = (half)(((q4.s3 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.sf = (half)(((q4.s3 & 0xF000) >> 12) - 8) * scale; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_0_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current Q4_0 block + uint s_offset = s_sub_offset + get_global_id(0); + half s = src0_d[s_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_0(as_ushort4(q4x16), reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_0(as_ushort4(q4x16), reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl new file mode 100644 index 00000000000..6f4d3f53216 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl @@ -0,0 +1,116 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_0 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_0_to_fp32_packed8(ushort2 q4x8) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) - 8); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) - 8); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) - 8); + fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) - 8); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) - 8); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) - 8); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) - 8); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) - 8); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_0_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_0); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + float8 fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From eb38a02de13c2778c18a514a4c93f3e49dda016d Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 7 May 2026 22:43:04 -0700 Subject: [PATCH 028/289] ggml: update SCHED_DEBUG output to use ggml_op_desc() (llama/22825) --- ggml/src/ggml-backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index d9f8aaec52f..4e36909f45e 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -965,7 +965,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name, + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_desc(node), node->name, fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0); for (int j = 0; j < GGML_MAX_SRC; j++) { From 803424ac5a03c1b05945b432f1ea94e1e1b5b1bb Mon Sep 17 00:00:00 2001 From: miyan <1138989048@qq.com> Date: Fri, 8 May 2026 15:35:22 +0800 Subject: [PATCH 029/289] vulkan: fix spv shadowing (llama/22760) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 423e01dbff1..0a7931002ab 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2149,11 +2149,11 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin // Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for // separate shader variants compiled with -DRTE16. - std::vector spv; + std::vector spirv; if (device->float_controls_rte_fp16) { const uint32_t* spv_words = reinterpret_cast(spv_data); size_t word_count = spv_size / sizeof(uint32_t); - spv.assign(spv_words, spv_words + word_count); + spirv.assign(spv_words, spv_words + word_count); // Find insertion points respecting SPIR-V layout order: // Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ... @@ -2163,9 +2163,9 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin size_t exec_insert_pos = pos; uint32_t entry_point_id = 0; - while (pos < spv.size()) { - uint32_t opcode = spv[pos] & spv::OpCodeMask; - uint32_t len = spv[pos] >> spv::WordCountShift; + while (pos < spirv.size()) { + uint32_t opcode = spirv[pos] & spv::OpCodeMask; + uint32_t len = spirv[pos] >> spv::WordCountShift; if (len == 0) break; if (opcode == spv::OpCapability) { @@ -2174,7 +2174,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } else if (opcode == spv::OpExtension) { ext_insert_pos = pos + len; } else if (opcode == spv::OpEntryPoint) { - entry_point_id = spv[pos + 2]; + entry_point_id = spirv[pos + 2]; exec_insert_pos = pos + len; } else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) { exec_insert_pos = pos + len; @@ -2189,7 +2189,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin // OpExecutionMode %entrypoint RoundingModeRTE 16 uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 }; - spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); + spirv.insert(spirv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); // OpExtension "SPV_KHR_float_controls" const char ext_str[] = "SPV_KHR_float_controls"; @@ -2197,13 +2197,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::vector extension(1 + ext_str_words, 0); extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension; memcpy(&extension[1], ext_str, sizeof(ext_str)); - spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end()); + spirv.insert(spirv.begin() + ext_insert_pos, extension.begin(), extension.end()); // OpCapability RoundingModeRTE uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE }; - spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); + spirv.insert(spirv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); - shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data()); + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); } pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); From ea459fba9d7c88bb2137f32607d7a28f8538e1b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 8 May 2026 10:09:38 +0200 Subject: [PATCH 030/289] CUDA: lower-case PCI bus id, standardize for ggml (llama/22820) --- ggml/include/ggml-backend.h | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index d0c7e5a1be0..b6f73739809 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -169,7 +169,7 @@ extern "C" { // device type enum ggml_backend_dev_type type; // device id - // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0") + // for PCI devices, this should be the lower-case PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:c1:00.0") // if the id is unknown, this should be NULL const char * device_id; // device capabilities diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8d21b2267f5..925a9ffe04c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5434,6 +5434,9 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { char pci_bus_id[32] = {}; CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i)); dev_ctx->pci_bus_id = pci_bus_id; + for (char & c : dev_ctx->pci_bus_id) { + c = std::tolower(c); + } dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_backend_dev_t dev = new ggml_backend_device { From 184f1a1383e3e6917527e723f8b4ccf9fd571550 Mon Sep 17 00:00:00 2001 From: Pascal Date: Fri, 8 May 2026 11:44:09 +0200 Subject: [PATCH 031/289] cuda: fuse snake activation (mul, sin, sqr, mul, add) (llama/22667) * cuda: fuse snake activation (mul, sin, sqr, mul, add) Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. Add test_snake_fuse comparing CPU naive vs CUDA fused across F32 / F16 / BF16. * cuda: address review feedback from @am17an Use ggml_cuda_cast for F32/F16/BF16 conversions and rename kernel_snake to snake_kernel to match upstream conventions. * cuda: snake fusion fastdiv on T_len, Suggested-by: @am17an * Update tests/test-backend-ops.cpp Co-authored-by: Aman Gupta * cuda: snake fusion check add->type matches x->type Address review feedback from @am17an * cuda: snake fusion check add->type matches x->type Moved for readability (equivalent) Address review feedback from @am17an --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++++ ggml/src/ggml-cuda/snake.cu | 72 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/snake.cuh | 8 ++++ 3 files changed, 110 insertions(+) create mode 100644 ggml/src/ggml-cuda/snake.cu create mode 100644 ggml/src/ggml-cuda/snake.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 925a9ffe04c..4df1b930882 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -39,6 +39,7 @@ #include "ggml-cuda/rope.cuh" #include "ggml-cuda/roll.cuh" #include "ggml-cuda/scale.cuh" +#include "ggml-cuda/snake.cuh" #include "ggml-cuda/softcap.cuh" #include "ggml-cuda/softmax.cuh" #include "ggml-cuda/ssm-conv.cuh" @@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph return 2; } + // Snake activation: y = x + sin(a*x)^2 * inv_b + // Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add + if (ggml_can_fuse_subgraph(cgraph, i, + { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD }, + { i + 4 })) { + const ggml_tensor * mul0 = cgraph->nodes[i]; + const ggml_tensor * sqr = cgraph->nodes[i + 2]; + const ggml_tensor * mul1 = cgraph->nodes[i + 3]; + ggml_tensor * add = cgraph->nodes[i + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + // closure check: the trailing add must read the same x as the leading mul + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; + + if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); + return 4; + } + } + // multi-(add or mul) if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { int n_fuse = 0; diff --git a/ggml/src/ggml-cuda/snake.cu b/ggml/src/ggml-cuda/snake.cu new file mode 100644 index 00000000000..384638c1f47 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cu @@ -0,0 +1,72 @@ +#include "snake.cuh" +#include "convert.cuh" + +// Fused Snake activation: y = x + sin^2(a * x) * inv_b +// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C] +// Supports F32, F16, BF16 data with F32 compute. + +template +static __global__ void snake_kernel( + const T * __restrict__ x, + const float * __restrict__ a, + const float * __restrict__ inv_b, + T * __restrict__ dst, + const int total, + const uint3 T_len_fastdiv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv); + + const float xi = ggml_cuda_cast(x[idx]); + const float s = sinf(a[c] * xi); + dst[idx] = ggml_cuda_cast(xi + s * s * inv_b[c]); +} + +// Internal launcher with explicit x/a/inv_b/dst tensors. +// Shared by the public op (reads dst->src) and the fusion path (explicit args). +static void launch_snake(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + const float * a_d = (const float *)a->data; + const float * inv_b_d = (const float *)inv_b->data; + + const int T = (int)x->ne[0]; + const int C = (int)x->ne[1]; + const int total = T * C; + const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T); + + const int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + + cudaStream_t stream = ctx.stream(); + + switch (x->type) { + case GGML_TYPE_F32: { + snake_kernel<<>>( + (const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_F16: { + snake_kernel<<>>( + (const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_BF16: { + snake_kernel<<>>( + (const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv); + } break; + default: + GGML_ABORT("snake: unsupported type"); + } +} + +// Fusion entry: caller supplies x/a/inv_b explicitly from the matched +// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + launch_snake(ctx, x, a, inv_b, dst); +} diff --git a/ggml/src/ggml-cuda/snake.cuh b/ggml/src/ggml-cuda/snake.cuh new file mode 100644 index 00000000000..7f6f1cb3b41 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cuh @@ -0,0 +1,8 @@ +#include "common.cuh" + +// Fusion entry point. Caller supplies x/a/inv_b explicitly. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst); From e0573051c6e7f814db4353429980482f663a0057 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Fri, 8 May 2026 13:41:40 -0700 Subject: [PATCH 032/289] Feature hexagon l2 norm (llama/22816) * L2_NORM Updates * Addressed PR Comments * ggml-hexagon: add L2_NORM HVX kernel for Hexagon backend * hex-unary: remove supported_unary_nc since the outer loop is the same for all unary ops --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 9 ++- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 81 ++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index df4ed101464..8ddd1915c83 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2420,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return false; } - // TODO: add support for non-contiguous elements within a row - if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) { + // dst must be contiguous; src0 may be non-contiguous + if (!ggml_is_contiguous(dst)) { return false; } @@ -2791,6 +2791,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; case GGML_OP_SCALE: return HTP_OP_SCALE; case GGML_OP_SQR: return HTP_OP_SQR; @@ -3253,6 +3254,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_L2_NORM: + supp = ggml_hexagon_supported_unary(sess, op); + break; + case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 66a3150c1a0..ef96ad38278 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -83,6 +83,8 @@ enum htp_op_code { HTP_OP_FILL, HTP_OP_DIAG, HTP_OP_SOLVE_TRI, + HTP_OP_L2_NORM, + HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 49c1a15b344..e18f1a0e61e 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_UNARY_SIGMOID: case HTP_OP_UNARY_NEG: case HTP_OP_UNARY_EXP: + case HTP_OP_L2_NORM: return op_unary(octx); case HTP_OP_UNARY_SILU: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 819cdc49bd9..26a0e0bd793 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -298,6 +298,81 @@ static void softplus_f32(const float * restrict src, } } +// --- L2_NORM HVX kernel --- +// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row. +// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers +// using rsqrt + inverse to avoid scalar extraction. +static void hvx_fast_l2_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + HVX_Vector sum_v = hvx_vec_splat_f32(0.0f); + + const int nvec = num_elems / VLEN_FP32; + const int nloe = num_elems % VLEN_FP32; + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq); + } + + // Include tail elements in the sum-of-squares using a predicate mask + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq); + } + + // Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers. + // hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction. + HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum) + HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum) + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon) + HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon) + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v)); + } + + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v)); + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + +static void l2_norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); + + hvx_fast_l2_norm_f32((const uint8_t *)src_f, (uint8_t *)dst_f, spad, row_elems, epsilon); + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; @@ -402,6 +477,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_UNARY_SOFTPLUS: softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_L2_NORM: + l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; default: break; } @@ -469,6 +547,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_UNARY_SOFTPLUS: op_type = "softplus-f32"; break; + case HTP_OP_L2_NORM: + op_type = "l2norm-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); From 892f786a653d19c2f474b3a9c56d9d4d7be2fb1f Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Fri, 8 May 2026 17:05:22 -0700 Subject: [PATCH 033/289] sycl: support non-contiguous input in PAD op (llama/22148) Signed-off-by: Chun Tao Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/ggml-sycl.cpp | 3 +- ggml/src/ggml-sycl/pad.cpp | 54 ++++++++++++++++---------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 29ecedb5de9..c3ac281067a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -5104,11 +5104,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ACC: return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_PAD: - // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985 if (ggml_get_op_params_i32(op, 8) != 0) { return false; } - return ggml_is_contiguous(op->src[0]); + return true; case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-sycl/pad.cpp b/ggml/src/ggml-sycl/pad.cpp index f989c5e4b8b..ee93bb51801 100644 --- a/ggml/src/ggml-sycl/pad.cpp +++ b/ggml/src/ggml-sycl/pad.cpp @@ -13,7 +13,8 @@ //#include "common.hpp" #include "pad.hpp" -static void pad_f32(const float * src, float * dst, +static void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, + float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, @@ -27,7 +28,6 @@ static void pad_f32(const float * src, float * dst, return; } - // operation const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && @@ -37,12 +37,8 @@ static void pad_f32(const float * src, float * dst, const int64_t i01 = i1 - lp1; const int64_t i02 = i2 - lp2; const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + - i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } else { @@ -50,20 +46,19 @@ static void pad_f32(const float * src, float * dst, } } -static void pad_f32_sycl(const float *src, float *dst, const int lp0, - const int rp0, const int lp1, const int rp1, - const int lp2, const int rp2, const int lp3, - const int rp3, const int ne0, const int ne1, - const int ne2, const int ne3, +static void pad_f32_sycl(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, + float * dst, const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3, dpct::queue_ptr stream) { int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; - dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3); + sycl::range<3> grid(ne2 * ne3, ne1, num_blocks); stream->parallel_for( - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), + sycl::nd_range<3>(grid * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, - ne2, ne3, item_ct1); + pad_f32(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + ne0, ne1, ne2, ne3, item_ct1); }); } @@ -71,22 +66,27 @@ void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; - dpct::queue_ptr stream = ctx.stream(); + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; - const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; - const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; - const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; - const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = src0->nb[0] / ts; + const size_t s01 = src0->nb[1] / ts; + const size_t s02 = src0->nb[2] / ts; + const size_t s03 = src0->nb[3] / ts; - pad_f32_sycl(src0_d, dst_d, + const int32_t lp0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t rp0 = ((const int32_t *)(dst->op_params))[1]; + const int32_t lp1 = ((const int32_t *)(dst->op_params))[2]; + const int32_t rp1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t lp2 = ((const int32_t *)(dst->op_params))[4]; + const int32_t rp2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t lp3 = ((const int32_t *)(dst->op_params))[6]; + const int32_t rp3 = ((const int32_t *)(dst->op_params))[7]; + + pad_f32_sycl(src0_d, s00, s01, s02, s03, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); } From 42aea65eda1b4100bedb42ea1cc7344da9f1054a Mon Sep 17 00:00:00 2001 From: Yanzhao Wang Date: Fri, 8 May 2026 17:12:04 -0700 Subject: [PATCH 034/289] hexagon: add HTP kernel for GGML_OP_GATED_DELTA_NET (llama/22837) Implement the Gated Delta Net recurrence on HVX with: - 4-row fused kernels for PP (prompt processing) path - 8-row fused kernels for TG (token generation) path, reducing K/Q/gate vector reload overhead by 2x - Separate PP/TG thread functions for I-cache isolation - VTCM state scratchpad with DMA in/out for TG single-cycle access - Vectorized gate exp via hvx_exp_f32 --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 111 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + .../ggml-hexagon/htp/gated-delta-net-ops.c | 955 ++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 1045 insertions(+), 27 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 8ddd1915c83..d3c125dbc3d 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2261,6 +2261,58 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return true; } +static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * q = op->src[0]; + const struct ggml_tensor * k = op->src[1]; + const struct ggml_tensor * v = op->src[2]; + const struct ggml_tensor * g = op->src[3]; + const struct ggml_tensor * beta = op->src[4]; + const struct ggml_tensor * state = op->src[5]; + const struct ggml_tensor * dst = op; + + if (!q || !k || !v || !g || !beta || !state) { + return false; + } + + if (q->type != GGML_TYPE_F32 || k->type != GGML_TYPE_F32 || v->type != GGML_TYPE_F32 || + g->type != GGML_TYPE_F32 || beta->type != GGML_TYPE_F32 || state->type != GGML_TYPE_F32 || + dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous_rows(q) || !ggml_is_contiguous_rows(k) || !ggml_is_contiguous_rows(v) || + !ggml_is_contiguous(g) || !ggml_is_contiguous(beta) || !ggml_is_contiguous(state) || + !ggml_is_contiguous(dst)) { + return false; + } + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { + return false; + } + if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] <= 0 || k->ne[1] <= 0 || + q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] <= 0 || k->ne[3] <= 0 || + (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { + return false; + } + if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { + return false; + } + if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { + return false; + } + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -2777,33 +2829,34 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { static htp_op_code op_remap_to_htp(const ggml_tensor * t) { switch (t->op) { - case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; - case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; - case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; - case GGML_OP_MUL: return HTP_OP_MUL; - case GGML_OP_ADD: return HTP_OP_ADD; - case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; - case GGML_OP_SUB: return HTP_OP_SUB; - case GGML_OP_DIV: return HTP_OP_DIV; - case GGML_OP_CPY: return HTP_OP_CPY; - case GGML_OP_CONT: return HTP_OP_CPY; - case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; - case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; - case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; - case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; - case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; - case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; - case GGML_OP_SCALE: return HTP_OP_SCALE; - case GGML_OP_SQR: return HTP_OP_SQR; - case GGML_OP_SQRT: return HTP_OP_SQRT; - case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; - case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; - case GGML_OP_ROPE: return HTP_OP_ROPE; - case GGML_OP_REPEAT: return HTP_OP_REPEAT; - case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; - case GGML_OP_FILL: return HTP_OP_FILL; - case GGML_OP_DIAG: return HTP_OP_DIAG; - case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; + case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; + case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; + case GGML_OP_MUL: return HTP_OP_MUL; + case GGML_OP_ADD: return HTP_OP_ADD; + case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; + case GGML_OP_SUB: return HTP_OP_SUB; + case GGML_OP_DIV: return HTP_OP_DIV; + case GGML_OP_CPY: return HTP_OP_CPY; + case GGML_OP_CONT: return HTP_OP_CPY; + case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; + case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; + case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; + case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; + case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_SCALE: return HTP_OP_SCALE; + case GGML_OP_SQR: return HTP_OP_SQR; + case GGML_OP_SQRT: return HTP_OP_SQRT; + case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; + case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; + case GGML_OP_GATED_DELTA_NET: return HTP_OP_GATED_DELTA_NET; + case GGML_OP_ROPE: return HTP_OP_ROPE; + case GGML_OP_REPEAT: return HTP_OP_REPEAT; + case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_FILL: return HTP_OP_FILL; + case GGML_OP_DIAG: return HTP_OP_DIAG; + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3341,6 +3394,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_ssm_conv(sess, op); break; + case GGML_OP_GATED_DELTA_NET: + supp = ggml_hexagon_supported_gated_delta_net(sess, op); + break; + case GGML_OP_CUMSUM: supp = ggml_hexagon_supported_cumsum(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 7c9e4cda5f1..bcadac11f95 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -37,6 +37,7 @@ add_library(${HTP_LIB} SHARED fill-ops.c diag-ops.c solve-tri-ops.c + gated-delta-net-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c new file mode 100644 index 00000000000..2e84badc9b7 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -0,0 +1,955 @@ +#include +#include +#include + +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +#define HTP_GDN_MAX_SV 128 + +struct htp_gdn_context { + struct htp_ops_context * octx; + uint32_t rows_per_thread; + size_t state_bytes; + bool use_vtcm; + uint8_t * vtcm_state_base; + size_t vtcm_state_per_thread; +}; + +static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, + const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, tail * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); +} + +static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, + const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, tail * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); +} + +static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + float scale, const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + const HVX_Vector vscale = hvx_vec_splat_f32(scale); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, tail * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); +} + +static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, const float * restrict mul, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float mul, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, const float * restrict src, + const float * restrict scale, const float * restrict dot, uint32_t n, + float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); + const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); + const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); + const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + const float * restrict mul, const float * restrict dot, uint32_t n, + float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + float mul, const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + const float * restrict src, const float * restrict scale, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); + const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); + const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); + const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); + const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]); + const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]); + const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]); + const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); + HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4)); + HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5)); + HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6)); + HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7)); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); + HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4)); + HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5)); + HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); + HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; + struct htp_ops_context * octx = gctx->octx; + + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_tokens = v->ne[2]; + const uint32_t n_seqs = v->ne[3]; + + const uint32_t total_rows = H * n_seqs; + if (ith >= total_rows) { + return; + } + + const uint32_t rq3 = n_seqs / q->ne[3]; + const uint32_t rk3 = n_seqs / k->ne[3]; + const float scale = 1.0f / sqrtf((float) S_v); + + float * dst_base = (float *) (uintptr_t) dst->data; + float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs; + const float * state_in_base = (const float *) (uintptr_t) state->data; + + const bool kda = (g->ne[0] == S_v); + float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_sums[4] __attribute__((aligned(128))); + + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + const uint32_t iv1 = ir % H; + const uint32_t iv3 = ir / H; + + const uint32_t iq1 = iv1 % q->ne[1]; + const uint32_t ik1 = iv1 % k->ne[1]; + const uint32_t iq3 = iv3 / rq3; + const uint32_t ik3 = iv3 / rk3; + + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + + memcpy(s_out, s_in, gctx->state_bytes); + float * s_work = s_out; + + float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; + + for (uint32_t t = 0; t < n_tokens; ++t) { + const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + + (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]); + const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + + (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]); + const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + + (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]); + const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + + (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]); + const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); + + memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); + memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + + if (kda) { + hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); + + uint32_t j = 0; + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } else { + const float gate = expf(g_t[0]); + uint32_t j = 0; + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } + + attn_data += (uint64_t) S_v * H; + } + } +} + +static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; + struct htp_ops_context * octx = gctx->octx; + + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_seqs = v->ne[3]; + + const uint32_t total_rows = H * n_seqs; + if (ith >= total_rows) { + return; + } + + const uint32_t rq3 = n_seqs / q->ne[3]; + const uint32_t rk3 = n_seqs / k->ne[3]; + const float scale = 1.0f / sqrtf((float) S_v); + + float * dst_base = (float *) (uintptr_t) dst->data; + float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs; + const float * state_in_base = (const float *) (uintptr_t) state->data; + + const bool kda = (g->ne[0] == S_v); + float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_sums[8] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + + uint8_t * spad = NULL; + if (gctx->use_vtcm) { + spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; + } + + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + const uint32_t iv1 = ir % H; + const uint32_t iv3 = ir / H; + + const uint32_t iq1 = iv1 % q->ne[1]; + const uint32_t ik1 = iv1 % k->ne[1]; + const uint32_t iq3 = iv3 / rq3; + const uint32_t ik3 = iv3 / rk3; + + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_work; + + if (spad) { + dma_queue_push(dma, dma_make_ptr(spad, s_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + dma_queue_pop(dma); + s_work = (float *) spad; + } else { + s_work = s_out; + memcpy(s_work, s_in, gctx->state_bytes); + } + + float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; + + const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + + (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]); + const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + + (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]); + const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + + (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]); + const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + + (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]); + const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); + + memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); + memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + + if (kda) { + hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); + + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row4 = s_work + (uint64_t) (j + 4) * S_v; + float * row5 = s_work + (uint64_t) (j + 5) * S_v; + float * row6 = s_work + (uint64_t) (j + 6) * S_v; + float * row7 = s_work + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + float local_delta_b[8] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 8; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 8; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } else { + const float gate = expf(g_t[0]); + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row4 = s_work + (uint64_t) (j + 4) * S_v; + float * row5 = s_work + (uint64_t) (j + 5) * S_v; + float * row6 = s_work + (uint64_t) (j + 6) * S_v; + float * row7 = s_work + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + float local_delta_b[8] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 8; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 8; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } + + if (spad) { + dma_queue_push(dma, dma_make_ptr(s_out, spad), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + dma_queue_pop(dma); + } + } +} + +int op_gated_delta_net(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + if (!q || !k || !v || !g || !beta || !state || !dst) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 || + g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 || + dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_tokens = v->ne[2]; + const uint32_t n_seqs = v->ne[3]; + + if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { + return HTP_STATUS_NO_SUPPORT; + } + if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { + return HTP_STATUS_NO_SUPPORT; + } + if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 || + q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 || + (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + return HTP_STATUS_NO_SUPPORT; + } + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_gdn_context gctx; + gctx.octx = octx; + gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads; + gctx.state_bytes = (size_t) S_v * S_v * sizeof(float); + + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + + gctx.use_vtcm = false; + gctx.vtcm_state_base = NULL; + gctx.vtcm_state_per_thread = 0; + + if (n_tokens == 1 && octx->ctx->vtcm_base) { + size_t vtcm_total = state_aligned * octx->n_threads; + if (octx->ctx->vtcm_size >= vtcm_total) { + gctx.use_vtcm = true; + gctx.vtcm_state_base = octx->ctx->vtcm_base; + gctx.vtcm_state_per_thread = state_aligned; + } + } + + if (n_tokens == 1) { + worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index e9c563ca887..92f02eac6e3 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -106,5 +106,6 @@ int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); +int op_gated_delta_net(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index ef96ad38278..6203e3848b9 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -84,6 +84,7 @@ enum htp_op_code { HTP_OP_DIAG, HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, + HTP_OP_GATED_DELTA_NET, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e18f1a0e61e..fa1e0698f4a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -594,6 +594,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_SOLVE_TRI: return op_solve_tri(octx); + case HTP_OP_GATED_DELTA_NET: + return op_gated_delta_net(octx); + case HTP_OP_INVALID: break; From 197c62c10b0fd0452a740704cbba3257526df04c Mon Sep 17 00:00:00 2001 From: AesSedai <7980540+AesSedai@users.noreply.github.com> Date: Fri, 8 May 2026 20:28:29 -0700 Subject: [PATCH 035/289] Add flash attention MMA / Tiles to support MiMo-V2.5 (llama/22812) * mimo-v2.5: add flash attention mma/tiles for for d_kq=192 d_v=128 * mimo-v2.5: follow (256, 256) fattn templates * mimo-v2.5: cleanup comments * mimo-v2.5: further comment cleanup * mimo-v2.5: address PR feedback fix GQA handling check for other dangling 320/576 carveouts and mirror them for 192 Add to backend ops test so new paths are covered --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 9 +++++ ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ ggml/src/ggml-cuda/fattn-tile.cuh | 40 ++++++++++++++++++- ggml/src/ggml-cuda/fattn.cu | 33 +++++++++++++-- ...ttn-mma-f16-instance-ncols1_1-ncols2_16.cu | 1 + ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 1 + ...ttn-mma-f16-instance-ncols1_2-ncols2_16.cu | 1 + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 1 + ...ttn-mma-f16-instance-ncols1_4-ncols2_16.cu | 1 + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 1 + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 1 + .../fattn-tile-instance-dkq192-dv128.cu | 5 +++ .../template-instances/generate_cu_files.py | 13 ++++-- 13 files changed, 102 insertions(+), 9 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3f01e858de7..43e22c5e5ee 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -61,6 +61,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); @@ -1561,6 +1566,10 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } + if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) { + NO_DEVICE_CODE; + return; + } #ifdef VOLTA_MMA_AVAILABLE if (ncols1*ncols2 < 32) { NO_DEVICE_CODE; diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index d60634cc0e9..c8281497d14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -34,6 +34,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; + case 192: { + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst); + } break; case 256: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 585f2c22853..7b0a5e5cf49 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -62,6 +62,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) @@ -124,6 +130,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) @@ -193,6 +205,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) @@ -264,6 +282,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) @@ -1250,7 +1274,20 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DKQ <= 512 && DKQ != 320) { + if constexpr (DKQ == 192) { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8"); + } + + if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1303,6 +1340,7 @@ extern DECL_FATTN_TILE_CASE( 80, 80); extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(192, 128); extern DECL_FATTN_TILE_CASE(256, 256); extern DECL_FATTN_TILE_CASE(320, 256); extern DECL_FATTN_TILE_CASE(512, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8256591b21d..e045b04f727 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -139,6 +139,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 128); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); break; + case 192: { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + GGML_ASSERT(V->ne[0] == 128); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst); + } else { + GGML_ASSERT(gqa_ratio % 8 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst); + } + } break; case 256: GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); @@ -368,6 +384,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 192: + if (V->ne[0] != 128 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 8 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 320: if (V->ne[0] != 256 || !gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; @@ -425,7 +449,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path. + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0; // If Turing tensor cores are available, use them: if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { @@ -454,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8; while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { gqa_ratio_eff *= 2; } @@ -468,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -501,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu index fb26abeb0da..b2661b93162 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 22d383173f3..6ae77bec895 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu index f011a208cd2..fd41e71b142 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 84b674cd05a..9f4bef11a44 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu index f5fd0e2369c..cc41fa52f13 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index 5906398db91..859bea5c525 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 4bc60d62f91..c975ce6b9b7 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu new file mode 100644 index 00000000000..b571cca0df2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(192, 128); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 5e9a1cb2eb3..af05a9eff71 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,10 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 192, 256, 320, 512, 576] + +# DKQ -> DV override for asymmetric head dims. +HEAD_SIZES_V_OVERRIDE = {576: 512, 320: 256, 192: 128} TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -62,7 +65,7 @@ def get_short_name(long_quant_name): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -85,15 +88,17 @@ def get_short_name(long_quant_name): if head_size_kq == 72: continue # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 192 and ncols2 not in (8, 16): # MiMo-V2.5 + continue if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - if head_size_kq not in (320, 576) and ncols2 in (16, 32): + if head_size_kq not in (192, 320, 576) and ncols2 in (16, 32): continue - head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: From 63f788320628650ca377d77f95504ca8538f6d46 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Fri, 8 May 2026 22:42:40 -0700 Subject: [PATCH 036/289] sycl: Battlemage AOT build via spir64_gen + MMQ subgroup annotations (llama/22147) * sycl: Battlemage AOT build via spir64_gen + MMQ subgroup annotations Signed-off-by: Chun Tao * Remove unneeded/unnecessary comments and annotations The MMQ subgroup annotations added are on functions gated behind ggml_sycl_supports_mmq(). Revisit the need for these annotations when that function changes. --------- Signed-off-by: Chun Tao Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/CMakeLists.txt | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 8e589fa238d..8f44c6ed080 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -135,7 +135,11 @@ endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") add_compile_definitions(GGML_SYCL_WARP_SIZE=16) - target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) + if (NOT GGML_SYCL_DEVICE_ARCH) + target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) + else() + message(STATUS "Skipping -ze-intel-greater-than-4GB-buffer-required for spir64_gen AOT") + endif() # Link against Intel oneMKL if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") @@ -160,7 +164,15 @@ if (GGML_SYCL_HOST_MEM_FALLBACK) endif() if (GGML_SYCL_DEVICE_ARCH) - target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) - target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) + message(STATUS "GGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} (AOT via spir64_gen)") + target_compile_options( + ggml-sycl PRIVATE + -fsycl-targets=spir64_gen + "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\"" + ) + target_link_options( + ggml-sycl PRIVATE + -fsycl-targets=spir64_gen + "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\"" + ) endif() - From 3542894544e53a429e6b6f110fbabfb1382d1898 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Fri, 8 May 2026 22:48:07 -0700 Subject: [PATCH 037/289] sycl: Q5_K reorder MMVQ/dequant + Q8_0 reorder MMVQ path (llama/22152) * sycl: Q5_K reorder MMVQ/dequant + Q8_0 reorder MMVQ path Signed-off-by: Chun Tao * Remove duplicate definitions --------- Signed-off-by: Chun Tao Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/convert.cpp | 29 ++++++++- ggml/src/ggml-sycl/dequantize.hpp | 57 ++++++++++++++++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 52 ++++++++++++++++ ggml/src/ggml-sycl/mmvq.cpp | 30 +++++++++- ggml/src/ggml-sycl/quants.hpp | 25 ++++++++ ggml/src/ggml-sycl/vecdotq.hpp | 98 +++++++++++++++++++++++-------- 6 files changed, 265 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 67b9c06f3e4..576f19d79ae 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -252,6 +252,23 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template +static void dequantize_row_q5_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(K_SCALE_SIZE), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); + }); + }); +} + template static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -643,7 +660,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_q4_K_sycl; } case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q5_K_sycl_reorder; + } else { + return dequantize_row_q5_K_sycl; + } case GGML_TYPE_Q6_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q6_K_sycl_reorder; @@ -718,7 +739,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_q4_K_sycl; } case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q5_K_sycl_reorder; + } else { + return dequantize_row_q5_K_sycl; + } case GGML_TYPE_Q6_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q6_K_sycl_reorder; diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 19fa88680d6..2324bfacd22 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -537,6 +537,63 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri #endif } +template +static void dequantize_block_q5_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + uint8_t * scales_local, const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { + const int64_t ib = item_ct1.get_group(2); + +#if QK_K == 256 + // assume 64 threads + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid / 16; // 0...3 + const int64_t ir = tid % 16; // 0...15 + const int64_t is = 2 * il; + + dst_t * y = yy + ib * QK_K + 64 * il + 2 * ir; + + const uint8_t * base = static_cast(vx); + + // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales (K_SCALE_SIZE per block)] [dm (half2 per block)] + const size_t qs_offset = ib * (QK_K / 2); + const size_t qh_offset = n_blocks * (QK_K / 2) + ib * (QK_K / 8); + const size_t scales_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + ib * K_SCALE_SIZE; + const size_t dm_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + n_blocks * K_SCALE_SIZE + ib * sizeof(ggml_half2); + + const uint8_t * qs_ptr = base + qs_offset; + const uint8_t * qh_ptr = base + qh_offset; + const uint8_t * scales_ptr = base + scales_offset; + const ggml_half2 dm_values = *reinterpret_cast(base + dm_offset); + + const float dall = dm_values.x(); + const float dmin = dm_values.y(); + + const uint8_t * ql = qs_ptr + 32 * il + 2 * ir; + const uint8_t * qh = qh_ptr + 2 * ir; + + if (tid < K_SCALE_SIZE) { + scales_local[tid] = scales_ptr[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + uint8_t sc, m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2 * il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + GGML_UNUSED(ib); GGML_UNUSED(tid); GGML_UNUSED(yy); GGML_UNUSED(scales_local); GGML_UNUSED(n_blocks); + GGML_ABORT("Q5_K reorder dequantize not supported for QK_K != 256"); +#endif +} + template static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c3ac281067a..f86ff3e9466 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3303,6 +3303,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { case GGML_TYPE_Q8_0: return true; case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: return !g_ggml_sycl_prioritize_dmmv; default: @@ -3325,6 +3326,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: return true; default: @@ -3541,6 +3543,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d return true; } +static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q5_K) == 0); + GGML_ASSERT(offset % sizeof(block_q5_K) == 0); + + const int nblocks = size / sizeof(block_q5_K); + + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + auto * qs_ptr = data_device; + auto * qh_ptr = qs_ptr + (QK_K / 2) * nblocks; + auto * scales_ptr = qh_ptr + (QK_K / 8) * nblocks; + auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); + + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q5_K * x = (const block_q5_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 2; ++j) { + qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < QK_K / 8; ++j) { + qh_ptr[ib * (QK_K / 8) + j] = x[ib].qh[j]; + } + + for (int j = 0; j < K_SCALE_SIZE; ++j) { + scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].dm; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; +} + static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q6_K) == 0); GGML_ASSERT(offset % sizeof(block_q6_K) == 0); @@ -3607,6 +3657,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q4_K: return reorder_qw_q4_k(data_device, size, 0, stream); + case GGML_TYPE_Q5_K: + return reorder_qw_q5_k(data_device, size, 0, stream); case GGML_TYPE_Q6_K: return reorder_qw_q6_k(data_device, size, 0, stream); default: diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 8fa2198f35a..49998f13ba8 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -839,6 +839,26 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, + nrows, nd_item); + }); + }); +} + static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -1125,6 +1145,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n"); mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; @@ -1145,7 +1166,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens } break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n"); + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q6_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 1f5b62740a8..806028ef3a3 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -79,6 +79,31 @@ template <> struct block_q_t { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI5_K; + static constexpr uint32_t qr = QR5_K; + static constexpr uint32_t vdr_mmvq = 2; + }; + + // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales] [dm] + static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { + auto qs_offset = block_index * (QK_K / 2); + auto qh_offset = n_blocks * (QK_K / 2) + block_index * (QK_K / 8); + return { qs_offset, qh_offset }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / QK_K)); + auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 8); + return { total_qs_bytes + block_index * K_SCALE_SIZE, + total_qs_bytes + nblocks * K_SCALE_SIZE + block_index * sizeof(ggml_half2) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + template <> struct block_q_t { struct traits { static constexpr uint32_t qk = QK_K; diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 9253168e5ea..d7770047424 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -357,38 +357,31 @@ template <> struct reorder_vec_dot_q_sycl { using q8_0_block = ggml_sycl_reordered::block_q_t; using q8_0_traits = typename q8_0_block::traits; - __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) { - int sumi = 0; - -#pragma unroll - for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { - // Q8_0 values are signed int8, no nibble extraction needed - // Direct dp4a: each int packs 4 int8 values - sumi = dpct::dp4a(v[i], u[i], sumi); - } - - const sycl::float2 ds8f = ds8.convert(); - - // Q8_0 has no bias term (values are signed), so just scale - return d8_0 * sumi * ds8f.x(); - } - __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, const int & iqs) { - const int8_t * bq8_0 = static_cast(vbq) + ibx_offset.first; - const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); - int v[q8_0_traits::vdr_mmvq]; - int u[q8_0_traits::vdr_mmvq]; + const uint8_t * base = static_cast(vbq); + const int8_t * qs = reinterpret_cast(base + ibx_offset.first); + const ggml_half d = *reinterpret_cast(base + d_offset.first); + + int v[q8_0_traits::vdr_mmvq]; + int u[q8_0_traits::vdr_mmvq]; #pragma unroll for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { - v[i] = get_int_from_int8(bq8_0, iqs + i); + v[i] = get_int_from_int8(qs, iqs + i); u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); } - return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds); - }; + int sumi = 0; +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + sumi = dpct::dp4a(v[i], u[i], sumi); + } + + const sycl::half2 ds_values = *q8_1_ds; + return static_cast(d) * static_cast(ds_values[0]) * sumi; + } }; static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, @@ -481,6 +474,65 @@ template <> struct reorder_vec_dot_q_sycl { } }; +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q5_K; + + using q5_k_block = ggml_sycl_reordered::block_q_t; + using q5_k_traits = typename q5_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast(vbq); + const uint8_t * qs = base + ibx_offset.first; // low 4 bits + const uint8_t * qh_base = base + ibx_offset.second; // high bit + const uint8_t * scs = base + d_offset.first; + const ggml_half2 * dms = reinterpret_cast(base + d_offset.second); + + const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2)); + const int * ql_ptr = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const int * qh_ptr = (const int *) (qh_base + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; + + int vl[2]; + int vh[2]; + int u[2 * QR5_K]; + float d8[QR5_K]; + + vl[0] = ql_ptr[0]; + vl[1] = ql_ptr[4]; + + vh[0] = qh_ptr[0] >> bq8_offset; + vh[1] = qh_ptr[4] >> bq8_offset; + + uint16_t aux[2]; + const int j = (QR5_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR5_K; ++i) { + const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); + + d8[i] = ds_values[0]; + + const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, *dms, d8); + } +}; + template <> struct reorder_vec_dot_q_sycl { static constexpr ggml_type gtype = GGML_TYPE_Q6_K; From 25f543175d0652204eacd643864a09a8b5fd39fe Mon Sep 17 00:00:00 2001 From: Devedse <2350015+devedse@users.noreply.github.com> Date: Sat, 9 May 2026 07:50:24 +0200 Subject: [PATCH 038/289] Add BF16 support to GET_ROWS operation (llama/21391) Add GGML_TYPE_BF16 to the SYCL backend's GET_ROWS operation, both in supports_op and in the kernel dispatch. This fixes a performance regression where models using BF16 embedding tensors (e.g., Gemma4's per_layer_token_embd.weight) fall back to CPU for the GET_ROWS op, causing a full GPU-to-CPU tensor transfer every token. The fix reuses the existing get_rows_sycl_float template with sycl::ext::oneapi::bfloat16, matching the pattern already used for sycl::half (F16) and float (F32). --- ggml/src/ggml-sycl/getrows.cpp | 4 ++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 1 + 2 files changed, 5 insertions(+) diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index 03f8dd90748..ca457454775 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -183,6 +183,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_BF16: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::ext::oneapi::bfloat16 *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_F32: get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f86ff3e9466..b6e705cdf3a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4974,6 +4974,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g { switch (op->src[0]->type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: From 8c7efe885cb38b9617ca952ebf7bc19d79bd6ffa Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Sat, 9 May 2026 15:30:39 +0900 Subject: [PATCH 039/289] SYCL: reduce allocation overhead during flash attention (llama/22732) * SYCL: reduce allocation overhead during flash attention * tidy up whitespace * add a note about the flag * move ggml_sycl_fattn_* into fattn-buffers.hpp * refactor implementation into fattn-buffers.cpp * move new_fattn_kv_buffers back into ggml-sycl.cpp --- ggml/src/ggml-sycl/common.hpp | 16 +++++++ ggml/src/ggml-sycl/fattn-buffers.cpp | 56 +++++++++++++++++++++++++ ggml/src/ggml-sycl/fattn-buffers.hpp | 63 ++++++++++++++++++++++++++++ ggml/src/ggml-sycl/fattn-common.hpp | 6 ++- ggml/src/ggml-sycl/ggml-sycl.cpp | 41 ++++++++++++++++++ 5 files changed, 180 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-sycl/fattn-buffers.cpp create mode 100644 ggml/src/ggml-sycl/fattn-buffers.hpp diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 5abf2290651..eec36e8db9a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -25,6 +25,7 @@ #include "presets.hpp" #include "type.hpp" #include "sycl_hw.hpp" +#include "fattn-buffers.hpp" namespace syclexp = sycl::ext::oneapi::experimental; @@ -404,12 +405,16 @@ struct ggml_backend_sycl_context { std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; std::unordered_map>> scratchpad_map; + std::unique_ptr fattn_bufs[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + static std::unique_ptr new_fattn_kv_buffers(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -421,6 +426,17 @@ struct ggml_backend_sycl_context { return pool(device); } + ggml_sycl_fattn_kv_buffers & fattn_buffers(int device) { + if (fattn_bufs[device] == nullptr) { + fattn_bufs[device] = new_fattn_kv_buffers(stream(device, 0), device); + } + return *fattn_bufs[device]; + } + + ggml_sycl_fattn_kv_buffers & fattn_buffers() { + return fattn_buffers(device); + } + #ifdef GGML_SYCL_GRAPH std::unique_ptr> exec_graph = nullptr; #endif diff --git a/ggml/src/ggml-sycl/fattn-buffers.cpp b/ggml/src/ggml-sycl/fattn-buffers.cpp new file mode 100644 index 00000000000..46cf6d551f1 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-buffers.cpp @@ -0,0 +1,56 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "common.hpp" + +sycl::half * ggml_sycl_fattn_kv_buffers::kv_buffer::ensure_half(size_t n_elems) { + const size_t need_bytes = n_elems * sizeof(sycl::half); + + if (capacity >= need_bytes) { + return ptr; + } + + if (ptr) { + SYCL_CHECK(CHECK_TRY_ERROR(qptr->wait())); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + ptr = nullptr; + capacity = 0; + } + + size_t cap = 0; + while (cap < need_bytes) { + cap += CHUNK_SIZE; + } + + void * dev_ptr; + SYCL_CHECK( + CHECK_TRY_ERROR(dev_ptr = sycl::malloc_device( + cap, *qptr))); + + if (!dev_ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, cap); + GGML_ABORT("fattn buffer alloc failed"); + } + + ptr = static_cast(dev_ptr); + capacity = cap; + return ptr; +} + +ggml_sycl_fattn_kv_buffers::kv_buffer::~kv_buffer() { +#ifdef DEBUG_SYCL_POOL + GGML_LOG_INFO("ggml_sycl_fattn_kv_buffer[%d]: %.2f MiB\n", device, capacity / 1024.0 / 1024.0); +#endif + if (ptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + } +} diff --git a/ggml/src/ggml-sycl/fattn-buffers.hpp b/ggml/src/ggml-sycl/fattn-buffers.hpp new file mode 100644 index 00000000000..c00461de620 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-buffers.hpp @@ -0,0 +1,63 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_FATTN_BUFFERS_HPP +#define GGML_SYCL_FATTN_BUFFERS_HPP + +#include + +typedef sycl::queue *queue_ptr; + +struct ggml_sycl_fattn_kv_buffers { + // buffers grow in chunks of this size + static constexpr size_t CHUNK_SIZE = 16ull << 20; // 16 MiB + + struct kv_buffer { + kv_buffer(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + ~kv_buffer(); + + kv_buffer(const kv_buffer &) = delete; + kv_buffer & operator=(const kv_buffer &) = delete; + + sycl::half * ensure_half(size_t n_elems); + + private: + sycl::half * ptr = nullptr; + size_t capacity = 0; + queue_ptr qptr = nullptr; + [[maybe_unused]] int device = 0; + }; + + kv_buffer K; + kv_buffer V; + + ggml_sycl_fattn_kv_buffers(queue_ptr qptr, int device) : K(qptr, device), V(qptr, device) {} + + ggml_sycl_fattn_kv_buffers(const ggml_sycl_fattn_kv_buffers &) = delete; + ggml_sycl_fattn_kv_buffers & operator=(const ggml_sycl_fattn_kv_buffers &) = delete; +}; + +/** + * Imitates `ggml_sycl_pool_alloc` to keep the code calling alloc unchanged. + */ +struct ggml_sycl_fattn_alloc { + ggml_sycl_fattn_kv_buffers::kv_buffer & buf; + sycl::half * ptr = nullptr; + + explicit ggml_sycl_fattn_alloc(ggml_sycl_fattn_kv_buffers::kv_buffer & buf_) : buf(buf_) {} + + sycl::half * alloc(size_t n_elems) { + ptr = buf.ensure_half(n_elems); + return ptr; + } +}; +#endif diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp index ed00d03c3b6..03f0c2623c8 100644 --- a/ggml/src/ggml-sycl/fattn-common.hpp +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -5,6 +5,7 @@ #include "common.hpp" #include "convert.hpp" #include "vecdotq.hpp" +#include "fattn-buffers.hpp" #include "ggml.h" @@ -918,12 +919,13 @@ void launch_fattn( GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); ggml_sycl_pool & pool = ctx.pool(); + ggml_sycl_fattn_kv_buffers & fbuf = ctx.fattn_buffers(); dpct::queue_ptr main_stream = ctx.stream(); const int id = ggml_sycl_get_device(); const int nsm = ggml_sycl_info().devices[id].nsm; - ggml_sycl_pool_alloc K_f16(pool); - ggml_sycl_pool_alloc V_f16(pool); + ggml_sycl_fattn_alloc K_f16(fbuf.K); + ggml_sycl_fattn_alloc V_f16(fbuf.V); ggml_sycl_pool_alloc KV_max(pool); ggml_sycl_pool_alloc dst_tmp(pool); ggml_sycl_pool_alloc dst_tmp_meta(pool); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b6e705cdf3a..e7768b8bf61 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1286,6 +1286,23 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {} ~ggml_sycl_pool_leg() { +#ifdef DEBUG_SYCL_POOL + int n_cached = 0; + size_t bytes_cached = 0; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + if (buffer_pool[i].ptr != nullptr) { + ++n_cached; + bytes_cached += buffer_pool[i].size; + } + } + GGML_LOG_INFO("%s: %d buffers, cached = %.2f MiB\n", __func__, + n_cached, bytes_cached / 1024.0 / 1024.0); + const auto slots = format_slots_in_alloc_order(); + if (!slots.empty()) { + GGML_LOG_INFO("%s: slots MiB: %s\n", __func__, slots.c_str()); + } +#endif + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { @@ -1296,6 +1313,26 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { GGML_ASSERT(pool_size == 0); } +#ifdef DEBUG_SYCL_POOL + std::string format_slots_in_alloc_order() const { + std::string line; + char buf[32]; + bool first = true; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + if (buffer_pool[i].ptr == nullptr) { + continue; + } + if (!first) { + line += '/'; + } + first = false; + snprintf(buf, sizeof(buf), "%.2f", buffer_pool[i].size / 1024.0 / 1024.0); + line += buf; + } + return line; + } +#endif + void * alloc(size_t size, size_t * actual_size) override { #ifdef DEBUG_sycl_MALLOC int nnz = 0; @@ -1459,6 +1496,10 @@ std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(q return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); } +std::unique_ptr ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) { + return std::unique_ptr(new ggml_sycl_fattn_kv_buffers(qptr, device)); +} + // TBD pool with virtual memory management // struct ggml_sycl_pool_vmm : public ggml_sycl_pool From 7072bdab9233401c45d27035a4cdbebd1d5bef49 Mon Sep 17 00:00:00 2001 From: scutler-nv Date: Sun, 10 May 2026 02:05:22 -0700 Subject: [PATCH 040/289] internal AllReduce kernel for CUDA provider (llama/22299) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: add internal AllReduce provider for tensor parallelism Introduces a NCCL-free AllReduce implementation for LLAMA_SPLIT_MODE_TENSOR using a single-phase CUDA kernel that pipelines D2H copy, cross-GPU handshake via pinned-memory volatile flags, and the reduction in one kernel launch per GPU. New files: - ggml/src/ggml-cuda/comm.cuh — ggml_cuda_allreduce_provider enum - ggml/src/ggml-cuda/allreduce.cuh — pipeline API declarations - ggml/src/ggml-cuda/allreduce.cu — kernel + pipeline init/dispatch ggml-cuda.cu changes: - ggml_backend_cuda_comm_context gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * llama-bench: add --allreduce flag to select AllReduce provider Adds --allreduce to llama-bench (and via the shared field pattern, consistent with other multi-value flags). Useful for isolating hangs or regressions in tensor-parallel mode: pass --allreduce nccl to force NCCL and bypass the internal provider. Also fixes ggml_cuda_select_allreduce_provider() to treat an empty GGML_CUDA_ALLREDUCE env var the same as unset (avoids spurious warning when llama-bench sets it to "" for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * llama-bench: rename --allreduce to --reduction-provider / -rp Co-Authored-By: Claude Sonnet 4.6 via the shared field pattern, consistent with other multi-value flags). Useful for isolating hangs or regressions in tensor-parallel mode: pass --allreduce nccl to force NCCL and bypass the internal provider. Also fixes ggml_cuda_select_allreduce_provider() to treat an empty GGML_CUDA_ALLREDUCE env var the same as unset (avoids spurious warning when llama-bench sets it to "" for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * llama-bench: pass WARN/ERROR log messages through in non-verbose mode The null log callback was silently dropping all messages. WARN and ERROR should always be visible since they indicate legitimate issues (e.g. a requested reduction provider not being available). Co-Authored-By: Claude Sonnet 4.6 vider. Also fixes ggml_cuda_select_allreduce_provider() to treat an empty GGML_CUDA_ALLREDUCE env var the same as unset (avoids spurious warning when llama-bench sets it to "" for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * cmake: improve NCCL detection for source-tree builds, add static/dynamic switch FindNCCL.cmake now searches the cmake source-build layout used by the Windows NCCL port (cmake/lib/Release for static, cmake/src/Release for dynamic import lib) and also checks src/include for the generated nccl.h header. New option GGML_CUDA_NCCL_STATIC (default OFF) selects static vs dynamic linking and controls which paths and library names are searched. Co-Authored-By: Claude Sonnet 4.6 for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: add AllReduce hang watchdog (GGML_CUDA_AR_WATCHDOG) When compiled with -DGGML_CUDA_AR_WATCHDOG=ON, uses a debug kernel variant that writes per-GPU spin diagnostics to pinned host memory. A host-side blocking poll (cudaEventQuery + volatile reads) detects hangs and logs WARN with the last observed arrival counters and spin counts, controlled by GGML_CUDA_AR_WATCHDOG (ms timeout) and GGML_CUDA_AR_MAX_SPIN (kernel bailout) env vars at runtime. Zero overhead on the production path — all debug code is behind #ifdef. Co-Authored-By: Claude Sonnet 4.6 ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: fix intermittent AllReduce hang on Blackwell PCIe Add __threadfence_system() before the arrival signal write in signal_set to ensure D2H data is globally visible before the peer observes the arrival flag. Without this fence, the peer could enter Phase 3 host reads before the data had fully landed, causing an intermittent deadlock on RTX 5090 (Blackwell, PCIe-only). Also redesign the watchdog from a blocking dispatch-thread poll to a non-blocking background thread, eliminating the ~20ms per-slot latency the old design added. Verified: 30/30 soak test runs clean at ~50 t/s (previously ~1-in-15 hang rate). Co-Authored-By: Claude Sonnet 4.6 - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: fix watchdog shutdown ordering and pipeline_free drain - Stop watchdog thread BEFORE destroying GPU resources (events, streams) to prevent polling destroyed handles → spurious "busy" readings - Add cudaStreamSynchronize in pipeline_free to drain in-flight kernels before freeing pinned host buffers they may still be reading - Sleep-first watchdog polling: no +0ms noise, only logs when a kernel is genuinely stuck past the poll interval - Check wdog_stop in both outer and inner loops so join() returns promptly instead of draining the entire queue - Add Phase 3 breadcrumbs to debug[3] for hang localization Co-Authored-By: Claude Sonnet 4.6 RNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: replace event-based watchdog with per-GPU ring buffer Completely rework the GGML_CUDA_AR_WATCHDOG system: - Replace the shared debug_buf + event-polling + queue design with per-GPU ring buffers in pinned host memory - Kernel writes a debug record only on spin-limit bailout: claims a ring slot via atomicAdd (single-GPU host atomics work on RTX 5090), writes fields, fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * fix: normalize line endings to LF (undo Windows CRLF conversion) Five files were inadvertently converted to CRLF by the Windows development environment, causing every line to show as changed in diffs against master. Co-Authored-By: Claude Sonnet 4.6 imit bailout: claims a ring slot via atomicAdd (single-GPU host atomics work on RTX 5090), writes fields, fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * .gitattributes: force LF line endings to prevent Windows CRLF conversion Co-Authored-By: Claude Sonnet 4.6 elopment environment, causing every line to show as changed in diffs against master. Co-Authored-By: Claude Sonnet 4.6 imit bailout: claims a ring slot via atomicAdd (single-GPU host atomics work on RTX 5090), writes fields, fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: move GGML_CUDA_AR_WATCHDOG from CMake option to local define The watchdog is development-only; a global CMake option is overkill. Move the toggle to a #define at the top of allreduce.cu (set to 0 by default) and remove the option from ggml/CMakeLists.txt and the CUDA CMakeLists.txt add_compile_definitions block. Co-Authored-By: Claude Sonnet 4.6 fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * unify kernel debug paths * use __threadfence_system explicitly (not in ggml_cuda_ar_signal_set) * preferentially use internal reduction for <=2 GPUs * templatize the main kernel to support fp16/bf16 * restore llama-bench.cpp changes * revert CMakeLists changes * remove notes from repo * remove dead warmup code * fix comments * improve reduction provider fallback code * add messages for allreduce fallback * rework reduction provider init to not call ncclCommInitAll if using the internal provider * fix case where a given tensor has not been computed * add chunked mode to the kernel for unlimited vector size * rework a few checks/fallbacks * various small cleanups * allow disabling CUDA reductions completely (falling back to the non-CUDA butterfly mode) * simplify reduction provider selection * minor simplifications * more cleanups/fixes * prototype alternate path for large reductions * chunked version of large reduction path * use bf16 for large reductions * experimental reduction using cudaMemcpyPeerAsync (slightly slower) * revert experimental change * add combined conversion/reduction kernel * add bf16 wire format for single kernel mode * experimental on-stream small reduction kernel * double buffer arrival slots, use token (incrementing) method * double buffer host_buf for small reductions * put in waits for use of host_mem in large reduction case (prevents stomping on in-use memory * remove watchdog code * various cleanups / dead code removal * fix fp16 mode * fix some comments/logging statements * use increasing token scheme for arrival signals * add top-level comment to allreduce.cu * improve top-level comment in allreduce.cu * fix comments in ggml_cuda_ar_kernel * improve event handling for hostmem buffer usage tracking * change ev_pool to fixed 2D array * add chunked memcpy fallback for extra-large reductions (>32 MB) * change thresholds for copy-engine path and bf16 demotion * multi-block kernel test * more fine-tuning for chukn-size, etc. * various fixes for PR review * more PR fixes * fix semantics of all host mappings * require ampere+ * small cleanups * properly use host pointer for src/dst in cudaMemcpy calls * allreduce: lazy-init the internal pipeline on first use A config that lives entirely on NCCL never needs the chunked-kernel pipeline (host_buf, host_large, dev_tmp, streams, events, arrival ring). Defer pipeline creation to the first try_allreduce_internal call using the same std::call_once pattern as ensure_nccl, so those resources stay unallocated when only NCCL is in use. Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: assert n_backends == 2 instead of soft-fallback ar_pipeline_init already requires n_devices == 2 and bails before any AR can get here, so by the time we reach try_allreduce_internal we know we have exactly two backends. Replace the runtime-debug-log fallback with a hard assert. Co-Authored-By: Claude Opus 4.7 (1M context) NCCL is in use. Co-Authored-By: Claude Opus 4.7 (1M context) * rework reduction provider selection. internal/nccl is OS dependent; most fallbacks are removed * remove unneeded Turing arch check (llama.cpp doesn't even compile pre-Turing anyway) * allreduce: ASCII-only comments and ggml_cuda_cast for value conversions Replace non-ASCII characters in comments (em dashes, right arrows) with ASCII equivalents (--, ->) so the source stays in the ggml/upstream norm. In the kernel-side code, replace static_cast/static_cast with ggml_cuda_cast<...> so the BF16 conversions go through the fast __float2bfloat16 / __bfloat162float intrinsics from convert.cuh. Pure pointer and integer casts stay as static_cast. Also drops two stray garbage tokens that snuck in from earlier merges (a duplicated 'return ok; }' tail in allreduce.cu and a leftover '_reg)' fragment in ggml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: use ggml_cuda_memcpy_1 for the chunked-kernel vector copies The chunked kernel's two 16-byte register<->host transfers (Phase 1 store and Phase 3 load) used reinterpret_cast on both sides. Replace with ggml_cuda_memcpy_1, which is the canonical helper for this pattern and emits the same int4 LD/ST under the hood. Conformance passes; 5x reruns of 70b internal pp512 show 1832-1836 t/s, matching the prior matrix value of 1831 t/s -- no perf change as expected. Co-Authored-By: Claude Opus 4.7 (1M context) ok; }' tail in allreduce.cu and a leftover '_reg)' fragment in ggml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: assert cuda_ctx->device matches the pipeline's device Both ggml_cuda_ar_pipeline and ggml_backend_cuda_context carry the device they were created for; if they ever disagree, every cuda call that follows runs on the wrong device. Add GGML_ASSERT at each cuda_ctx retrieval site in the AR path so the misuse fails fast rather than silently corrupting. Also: rename __nv_bfloat16 -> nv_bfloat16 (typedef alias) for consistency with the rest of the file, and tighten one cudaGetLastError check to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: expand one-liner for loops to braced bodies Code-style preference -- match the rest of the file by writing every for loop with the body on its own braced line. Three sites in the copy-engine typed dispatch. Co-Authored-By: Claude Opus 4.7 (1M context) in the AR path so the misuse fails fast rather than silently corrupting. Also: rename __nv_bfloat16 -> nv_bfloat16 (typedef alias) for consistency with the rest of the file, and tighten one cudaGetLastError check to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: rename template parameters Tdst/Twire/Tsrc -> T_dst/T_wire/T_src Code-style preference per PR review -- T_dst/T_wire/T_src is more consistent with surrounding code. Whole-word rename across all 58 sites in allreduce.cu (kernel definitions, internal uses, and comment text). Realigned the parameter columns in three function signatures whose T_src/T_dst lines shifted by 1 char relative to their non-templated neighbors. Co-Authored-By: Claude Opus 4.7 (1M context) to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: drop hyphen in 'chunked-kernel' across comments Per PR review feedback -- 'chunked kernel' (no hyphen) reads more naturally in running prose, especially for ESL readers. Pure comment-only change; all 10 occurrences in allreduce.cu updated. Co-Authored-By: Claude Opus 4.7 (1M context) three function signatures whose T_src/T_dst lines shifted by 1 char relative to their non-templated neighbors. Co-Authored-By: Claude Opus 4.7 (1M context) to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: use ggml_cuda_get_max_cpy_bytes() instead of hardcoded 16 The chunked kernel hardcoded a 16-byte vector unit; replace with the ggml_cuda_get_max_cpy_bytes() helper that fattn-common.cuh uses for the same purpose, so ELEMS_PER_VEC self-adjusts to the arch's widest single-instruction copy. Perf-neutral on supported targets (Volta+ returns 16). Co-Authored-By: Claude Opus 4.7 (1M context) hbors. Co-Authored-By: Claude Opus 4.7 (1M context) to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * ggml-cuda: PR review fixes -- annotate #endif, fix stale comment, assert nbytes alignment Three separate but minor changes from PR #22299 review feedback: 1. Annotate the five GGML_USE_NCCL #endif lines with the matching condition so the pairing is visible without scrolling back. 2. The comment block on ggml_backend_cuda_comm_context claimed NCCL is lazy-initialised; that was true at one point but the dispatch refactor (727b141c0) made both NCCL and the internal pipeline eager. Rewrite the comment to match current behaviour. 3. Assert in ggml_backend_cuda_comm_allreduce_internal that the tensor's byte size is a 16-byte multiple. The chunked-kernel issues full-width vector loads/stores, so this is a precondition; tensor-parallel splits of hidden-dim-multiples satisfy it trivially, but a hard assert turns any caller-side bug into a clear failure rather than UB. Co-Authored-By: Claude Opus 4.7 (1M context) device's new AR records its ev.ker -- otherwise the second device's wait sees the first device's just-recorded event (the in-flight new AR) and creates a circular dependency with the in-kernel peer signal. Two-pass dispatch (all waits, then all launches) avoids this. Bump POOL_SIZE 2 -> 8 (small memory cost, more breathing room for the GPU's view of the event chain) and add a runtime env override for the hybrid kernel chunk size (GGML_CUDA_AR_HYBRID_CHUNK_BYTES) for tuning. One-shot stderr diagnostic at first AR prints the chosen path + sizing. Result on 2x RTX 5090 Linux, 70b ub_sweep: ub=64 (1 MB AR): 913 -> 1036 t/s (+13.5% vs old, +1.8% vs NCCL) ub=128 (2 MB AR): 1056 -> 1181 (+11.9%, +3.7% vs NCCL) ub=256 (4 MB AR): 1212 -> 1424 (+17.5%, +3.5% vs NCCL) Internal now beats NCCL at every size (+1.8% to +15.6%), recovering all ground in the 1-4 MB regime that was previously a 10-12% loss. Co-Authored-By: Claude Opus 4.7 (1M context) * simplify the init logic * address some other PR requests * ggml-cuda: stub internal AllReduce on HIP/MUSA, drop pre-Ampere mention, gate NCCL fallback warning on !HIP The internal AllReduce relies on cudaHostAllocPortable/Mapped, cudaHostGetDevicePointer, and __nanosleep -- none of which the HIP or MUSA shims expose -- so wrap the implementation in !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) and provide nullptr/no-op/false stubs in the #else branch. The dispatcher already treats a null pipeline as init failure and silently falls back to the meta backend's generic AllReduce, so HIP/MUSA builds compile clean and behave correctly without further call-site changes. PR review follow-ups: - drop "or pre-Ampere?" from the internal-init failure warning -- the kernel doesn't require Ampere or newer. - guard the "NCCL not compiled in" fallback warning behind !defined(GGML_USE_HIP); the suggestion to install NCCL only makes sense on NVIDIA builds. Co-Authored-By: Claude Opus 4.7 (1M context) hind, now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: guard __nanosleep on Volta+ and reject pre-Volta devices at init __nanosleep is the only Volta-specific intrinsic in the kernel; wrap it in #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA / NO_DEVICE_CODE so the file still compiles cleanly when targeting older arches (the dispatcher's init check below ensures the kernel is never actually launched on pre-Volta). Add a per-device compute-capability check in pipeline_init that returns nullptr if any device is below sm70. The dispatcher already treats nullptr as init failure and silently falls back to the meta backend's generic AllReduce. Co-Authored-By: Claude Opus 4.7 (1M context) rom the internal-init failure warning -- the kernel doesn't require Ampere or newer. - guard the "NCCL not compiled in" fallback warning behind !defined(GGML_USE_HIP); the suggestion to install NCCL only makes sense on NVIDIA builds. Co-Authored-By: Claude Opus 4.7 (1M context) hind, now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: fix CI -Werror warnings (sign-compare, format, restrict alias, maybe-uninitialized) The CUDA CI builds with -Werror -Wsign-compare -Wformat -Wrestrict -Wmaybe-uninitialized. Address each: - n_devices is size_t; change `int i; i < n_devices` to size_t in the three init loops, and the matching GGML_LOG_INFO format from %d to %zu. - ggml_cuda_ar_kernel was launched with sendbuf == recvbuf (in-place reduction), so the __restrict__ qualifiers on those parameters were technically UB. Drop __restrict__ from sendbuf and recvbuf; an A/B sweep showed <0.6% perf delta (within noise) on Linux. - The buf/src/dst pointer arrays in ggml_cuda_ar_allreduce and the per-iteration arrays in ggml_cuda_ar_allreduce_copy_outer were declared with size GGML_CUDA_MAX_DEVICES but the loop only writes indices [0, n_devices); zero-initialise so the compiler sees the tail elements as defined. Co-Authored-By: Claude Opus 4.7 (1M context) now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) * ggml-cuda: drop unused-function warning by guarding try_allreduce_nccl behind GGML_USE_NCCL The only call site (in init_nccl) is already inside #ifdef GGML_USE_NCCL, so the function is unreferenced in non-NCCL builds and trips nvcc's -Werror=unused-function check. Move the guard from inside the function body to around the entire definition. Co-Authored-By: Claude Opus 4.7 (1M context) ce reduction), so the __restrict__ qualifiers on those parameters were technically UB. Drop __restrict__ from sendbuf and recvbuf; an A/B sweep showed <0.6% perf delta (within noise) on Linux. - The buf/src/dst pointer arrays in ggml_cuda_ar_allreduce and the per-iteration arrays in ggml_cuda_ar_allreduce_copy_outer were declared with size GGML_CUDA_MAX_DEVICES but the loop only writes indices [0, n_devices); zero-initialise so the compiler sees the tail elements as defined. Co-Authored-By: Claude Opus 4.7 (1M context) now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Sonnet 4.6 --- ggml/src/ggml-cuda/allreduce.cu | 968 +++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/allreduce.cuh | 29 + ggml/src/ggml-cuda/ggml-cuda.cu | 265 +++++++-- 3 files changed, 1205 insertions(+), 57 deletions(-) create mode 100644 ggml/src/ggml-cuda/allreduce.cu create mode 100644 ggml/src/ggml-cuda/allreduce.cuh diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu new file mode 100644 index 00000000000..434689abd95 --- /dev/null +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -0,0 +1,968 @@ +#include "allreduce.cuh" + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + +#include "convert.cuh" +#include "ggml-impl.h" + +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// CUDA AllReduce for tensor-parallel inference across two GPUs. +// +// Provides an in-place sum reduction over matching tensors on two CUDA +// devices in the same process. Used by the tensor-split path alongside +// NCCL; targets setups without NVLink, where data is exchanged between the +// GPUs by staging it through pinned host memory over PCIe. +// +// Two reduction strategies are selected per call by tensor size: +// +// * Chunked kernel path (small reductions): a single CUDA kernel both +// stages data through pinned host memory and performs the local sum. +// Cross-GPU synchronization happens *inside the kernel* (busy-wait on +// a host-memory flag), which keeps launch overhead low for the +// latency-sensitive token-generation case. +// +// * Copy-engine path (large reductions): the transfer is split into +// D2H + H2D cudaMemcpyAsync chunks driven by the GPU's copy engine, +// followed by a small device-side add kernel. Cross-GPU +// synchronization happens *outside the kernel*, via CUDA events +// between streams. This keeps the compute engine free while large +// transfers are in flight, which matters for prefill-sized tensors. +// Reductions larger than the per-call inner cap are processed by an +// outer chunker that issues sequential inner calls. +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Cross-GPU signal mechanism +// +// One int per (slot, rank) pair in pinned host memory. Each AR call writes a +// strictly increasing token (= the AR call number) into its own arrival int. +// The peer spins until its read of the other's arrival int equals the token +// it expects for this call -- a mismatch means the peer hasn't arrived yet. +// Tokens never repeat over realistic call rates (32-bit int wraps in tens of +// days at thousands of ARs/sec), so arrival ints don't need to be reset +// between calls; we initialize once at pipeline init and let the values +// accumulate. +// +// There is exactly one writer (the owning GPU) and one reader (the peer), so +// we don't need atomics. A volatile store paired with __threadfence_system() +// provides the release ordering that makes the D2H writes visible system-wide +// before the arrival token is observed. +// +// atomicAdd_system() requires hostNativeAtomicSupported, which is unavailable +// on PCIe-attached consumer GPUs without NVLink, so the volatile path is the +// portable choice. +// --------------------------------------------------------------------------- + +static __device__ __forceinline__ void ggml_cuda_ar_signal_set(int * p, int token) { + *(volatile int *)p = token; +} +static __device__ __forceinline__ int ggml_cuda_ar_signal_get(const int * p) { + return *(const volatile int *)p; +} + +// Byte spacing between adjacent arrival ints. 64 bytes (one cache line) +// ensures each GPU/block's arrival slot lives on its own line, preventing +// false-sharing stalls on the polling GPU. +static constexpr size_t GGML_CUDA_AR_ARRIVAL_STRIDE = 64; + +// Number of blocks the chunked kernel launches with. Each block stripes a +// disjoint slice of the data and synchronizes through its own arrival-token +// slot so multiple SMs can pump PCIe stores in parallel. +static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; + +// --------------------------------------------------------------------------- +// Chunked kernel AllReduce -- 2 GPUs, supports float, half, and bfloat16. +// +// Both GPUs run this kernel simultaneously on independent streams. sendbuf +// and recvbuf live in T_dst (the caller's tensor type); host_mine / host_other +// carry data in T_wire (the on-wire type, possibly narrower than T_dst -- e.g. +// T_dst=F32 with T_wire=BF16 halves the bytes pushed across PCIe). When +// T_dst == T_wire the casts below are no-ops. +// +// Each GPU runs three phases: +// +// Phase 1 (all threads): cast sendbuf (T_dst) -> T_wire and store as +// single-instruction-width vectors into host_mine. +// __threadfence_system() commits these writes to host +// memory. +// Phase 2 (thread 0): write token to arrival_mine; spin until +// arrival_other == token. +// Phase 3 (all threads): read T_wire vectors from host_other, cast +// each element to T_dst, and sum with the local +// sendbuf value (also rounded through T_wire so that +// both GPUs truncate identically -- this guarantees +// bit-equivalent results across the two devices). +// +// Multi-block: blocks stripe vectors across (gridDim.x * blockDim.x) global +// threads to keep multiple SMs issuing PCIe stores in parallel. Each block +// has its own arrival-token slot (offset by blockIdx.x * ARRIVAL_STRIDE); +// thread 0 of each block signals/spins on that slot independently of other +// blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are +// handled only by block 0 to avoid cross-block writes to the same slots. +// --------------------------------------------------------------------------- +template +static __global__ void ggml_cuda_ar_kernel( + const T_dst * sendbuf, + T_dst * recvbuf, + T_wire * __restrict__ host_mine, + const T_wire * __restrict__ host_other, + int count, + int * arrival_mine, + int * arrival_other, + int token) { + + // Vector unit for the wire type, sized to the arch's widest single-instruction + // copy (16 B on Volta+). Each phase-1 iter writes one vector to host memory; + // each phase-3 iter reads one and produces ELEMS_PER_VEC sums. + constexpr int ELEMS_PER_VEC = ggml_cuda_get_max_cpy_bytes() / sizeof(T_wire); + constexpr int ARRIVAL_INTS = (int)(GGML_CUDA_AR_ARRIVAL_STRIDE / sizeof(int)); + + const int tid = threadIdx.x; + const int nt = blockDim.x; + const int bid = blockIdx.x; + const int gtid = bid * nt + tid; + const int gnt = gridDim.x * nt; + const int count_vec = count / ELEMS_PER_VEC; + const int tail = count_vec * ELEMS_PER_VEC; + + // Phase 1: cast sendbuf (T_dst) -> host_mine (T_wire) and store as vectors. + { + for (int i = gtid; i < count_vec; i += gnt) { + const int off = i * ELEMS_PER_VEC; + T_wire wire[ELEMS_PER_VEC]; + #pragma unroll + for (int k = 0; k < ELEMS_PER_VEC; ++k) { + wire[k] = ggml_cuda_cast(sendbuf[off + k]); + } + ggml_cuda_memcpy_1(&host_mine[off], wire); + } + if (bid == 0 && tid < count - tail) { + host_mine[tail + tid] = ggml_cuda_cast(sendbuf[tail + tid]); + } + } + + // Commit this block's host writes before signalling. + __threadfence_system(); + __syncthreads(); + + // Phase 2: thread 0 of each block signals on its own arrival slot, then + // spins for the matching slot from peer. Per-block tokens mean blocks + // proceed independently -- no inter-block barrier needed. + if (tid == 0) { + int * my_slot = arrival_mine + bid * ARRIVAL_INTS; + const int * other_slot = arrival_other + bid * ARRIVAL_INTS; + + ggml_cuda_ar_signal_set(my_slot, token); + __threadfence_system(); // make our signal visible system-wide + + while (ggml_cuda_ar_signal_get(other_slot) != token) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + __nanosleep(100); +#else + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + } + + __syncthreads(); + + // Acquire peer's host_other writes (this block's stripe of them). + __threadfence_system(); + + // Phase 3: read peer's T_wire vector, cast both sides through T_wire for + // bit-equivalence, sum in T_dst precision, and write back to recvbuf. + { + for (int i = gtid; i < count_vec; i += gnt) { + const int off = i * ELEMS_PER_VEC; + T_wire wire[ELEMS_PER_VEC]; + ggml_cuda_memcpy_1(wire, &host_other[off]); + #pragma unroll + for (int k = 0; k < ELEMS_PER_VEC; ++k) { + const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); + recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); + } + } + if (bid == 0 && tid < count - tail) { + const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); + recvbuf[tail + tid] = + ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); + } + } +} + +// Combined load-convert-add kernel. The peer's contribution arrives as T_src +// (which may be a lower-precision type than T_dst when the BF16 round-trip is +// active). For bit-equivalence between the two GPUs, dst is first rounded +// through T_src's precision via ggml_cuda_cast -- peer already truncated its +// own value the same way before sending -- so both sides perform identical +// arithmetic. When T_dst == T_src the round-trip cast is a no-op. +template +static __global__ void ggml_cuda_ar_add_kernel( + T_dst * __restrict__ dst, + const T_src * __restrict__ src, + int count) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int nt = gridDim.x * blockDim.x; + for (int i = tid; i < count; i += nt) { + const T_src d_low = ggml_cuda_cast(dst[i]); + dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); + } +} + +// --------------------------------------------------------------------------- +// Pipeline structure +// --------------------------------------------------------------------------- + +// Number of slots in the event / arrival ring. Two slots is sufficient: +// lockstep guarantees the two GPUs are at most one AR (or chunk) apart, so +// slot[N%2] is always safe to reuse -- peer has already consumed slot[N%2] +// from AR N-2 by the time we get to AR N. acquire_slot's +// cudaEventSynchronize on ev.ker for both devices makes that consumption +// explicit before we overwrite host_buf[slot] for the new AR. +static constexpr int GGML_CUDA_AR_POOL_SIZE = 2; + +// Maximum chunk size (bytes per GPU) handled by one chunked kernel launch. +// Larger tensors are reduced by issuing multiple chunked launches. +static constexpr size_t GGML_CUDA_AR_MAX_BYTES = 1024 * 1024; // 1 MB + +// Copy-engine path: largest tensor accepted on this path; sets host_large / +// dev_tmp allocation size. +static constexpr size_t GGML_CUDA_AR_COPY_MAX_BYTES = 32 * 1024 * 1024; // 32 MB + +// AR wire size at which the copy-engine path takes over from the chunked- +// kernel path. Override via GGML_CUDA_AR_COPY_THRESHOLD. +static constexpr size_t GGML_CUDA_AR_COPY_THRESHOLD_DEFAULT = 1024 * 1024; // 1 MB +// Per-call CE chunk-size heuristic: chunk_bytes = clamp(nbytes / 4, MIN, MAX). +// The /4 keeps ~4 chunks in flight at any moment (good D2H/H2D overlap with +// the peer); the clamps cover the cases where nbytes/4 is too small (per- +// memcpy fixed cost dominates) or too large (chunk-level pipelining stalls). +// Env var GGML_CUDA_AR_COPY_CHUNK_BYTES can override with a fixed value. +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MIN = 512 * 1024; // 512 KB +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MAX = 2 * 1024 * 1024; // 2 MB +// Absolute floor that an env-var override is allowed to set; this caps the +// per-slot copy-event array. 256 KB -> up to 128 chunks per 32 MB tensor. +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN = 256 * 1024; +static constexpr int GGML_CUDA_AR_COPY_MAX_CHUNKS = + static_cast((GGML_CUDA_AR_COPY_MAX_BYTES + GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN - 1) / + GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN); + +struct ggml_cuda_ar_event_slot { + cudaEvent_t app = nullptr; // upstream computation complete + cudaEvent_t cpy[GGML_CUDA_AR_COPY_MAX_CHUNKS] = {}; // copy-engine D2H chunks complete + cudaEvent_t h2d = nullptr; // copy-engine H2Ds complete (handoff AR stream -> compute stream) + cudaEvent_t ker = nullptr; // AllReduce kernel complete +}; + +// Mapped pinned host allocation: cudaHostAlloc + cudaHostGetDevicePointer +// in one place, with the host handle preserved for cudaFreeHost. Used where +// the CPU never touches the buffer -- only the device reads/writes via the +// mapped device pointer. Required on systems where cudaDevAttrCanUseHost- +// PointerForRegisteredMem is 0 and the host pointer can't be used as a +// device pointer. +struct ggml_cuda_ar_host_mapping { + uint8_t * host = nullptr; // cudaFreeHost handle; also the H-side ptr for cudaMemcpyAsync + uint8_t * dev = nullptr; // device-side pointer for kernels / cudaMemset + + cudaError_t alloc(size_t bytes) { + cudaError_t rc = cudaHostAlloc(reinterpret_cast(&host), bytes, + cudaHostAllocPortable | cudaHostAllocMapped); + if (rc != cudaSuccess) { + host = nullptr; + return rc; + } + rc = cudaHostGetDevicePointer(reinterpret_cast(&dev), host, 0); + if (rc != cudaSuccess) { + cudaFreeHost(host); + host = nullptr; + dev = nullptr; + } + return rc; + } + + void free() { + if (host) { + cudaFreeHost(host); + host = nullptr; + dev = nullptr; + } + } +}; + +struct ggml_cuda_ar_pipeline { + int n_devices; + int devices[GGML_CUDA_MAX_DEVICES]; + size_t buf_bytes; // bytes per device in host_buf[] + size_t copy_bytes; // bytes per device in host_large[] / dev_tmp[] + size_t copy_threshold; + size_t copy_chunk_bytes; + size_t bf16_threshold; // tensors >= this size (bytes) are reduced via FP32->BF16 round-trip; 0 disables + uint64_t call_count; + + // Per-device resources. + ggml_cuda_ar_host_mapping host_buf[GGML_CUDA_MAX_DEVICES]; // pinned staging (chunked kernel) + ggml_cuda_ar_host_mapping host_large[GGML_CUDA_MAX_DEVICES]; // pinned staging (copy-engine) + char * dev_tmp[GGML_CUDA_MAX_DEVICES]; // device scratch for copy-engine path + cudaStream_t streams[GGML_CUDA_MAX_DEVICES]; // non-blocking + ggml_cuda_ar_event_slot ev_pool[GGML_CUDA_MAX_DEVICES][GGML_CUDA_AR_POOL_SIZE]; + + // Copy-engine: per-device "I finished reading my peer's host_large" + // event. Indexed by RECORDER device. Recorded same-device on streams[i] + // after stage 2's last H2D from host_large[peer]. Waited cross-device + // by peer's stage-1 stream before the next AR overwrites host_large[peer]. + cudaEvent_t host_large_read_done[GGML_CUDA_MAX_DEVICES]; + bool host_large_read_done_valid; + + // Copy-engine: per-device "my add_kernel is done with dev_tmp" event. + // Recorded on the compute stream after each add_kernel; the AR stream + // waits on it before the next copy_impl's H2D overwrites dev_tmp. Lets us + // single-buffer dev_tmp despite add_kernel running on a separate stream. + cudaEvent_t dev_tmp_kernel_done[GGML_CUDA_MAX_DEVICES]; + bool dev_tmp_kernel_done_valid; + + // Arrival ring: ARRIVAL_STRIDE bytes between adjacent ints. Mapped pinned + // memory; CPU never reads/writes -- only the kernel and cudaMemset. + // Use ggml_cuda_ar_arrival_ptr() to index. + ggml_cuda_ar_host_mapping arrival; +}; + +// Base pointer for the (slot, rank) per-block token block. The kernel adds +// blockIdx.x * (ARRIVAL_STRIDE/sizeof(int)) internally to land on its own slot. +static int * ggml_cuda_ar_arrival_ptr(const ggml_cuda_ar_pipeline * p, int slot, int rank) { + const size_t offset = ((size_t)slot * p->n_devices + rank) * + GGML_CUDA_AR_KERNEL_BLOCKS * GGML_CUDA_AR_ARRIVAL_STRIDE; + return reinterpret_cast(p->arrival.dev + offset); +} + +static uint64_t ggml_cuda_ar_env_u64(const char * name, uint64_t default_value) { + const char * value = getenv(name); + if (value == nullptr || value[0] == '\0') { + return default_value; + } + + char * end = nullptr; + const unsigned long long parsed = strtoull(value, &end, 10); + return end != value ? (uint64_t) parsed : default_value; +} + +struct ggml_cuda_ar_slot_info { + int slot; + int token; +}; + +static ggml_cuda_ar_slot_info ggml_cuda_ar_acquire_slot(ggml_cuda_ar_pipeline * p) { + const int slot = static_cast(p->call_count % GGML_CUDA_AR_POOL_SIZE); + const bool pool_lapped = p->call_count >= GGML_CUDA_AR_POOL_SIZE; + p->call_count++; + + if (pool_lapped) { + for (int i = 0; i < p->n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + CUDA_CHECK(cudaEventSynchronize(p->ev_pool[i][slot].ker)); + } + } + + return { slot, (int) p->call_count }; +} + +// Per-AR copy-engine chunk size: env-var override if set, else heuristic +// (clamp(nbytes/4, HEURISTIC_MIN, HEURISTIC_MAX)). +static size_t ggml_cuda_ar_chunk_bytes(const ggml_cuda_ar_pipeline * p, size_t nbytes) { + if (p->copy_chunk_bytes > 0) { + return p->copy_chunk_bytes; + } + return std::min(GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MAX, + std::max(GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MIN, nbytes / 4)); +} + +static void ggml_cuda_ar_wait_for_compute( + ggml_cuda_ar_pipeline * p, ggml_backend_cuda_context * cuda_ctx, int rank, int slot) { + ggml_cuda_ar_event_slot & ev = p->ev_pool[rank][slot]; + CUDA_CHECK(cudaEventRecord(ev.app, cuda_ctx->stream())); + CUDA_CHECK(cudaStreamWaitEvent(p->streams[rank], ev.app)); +} + +// --------------------------------------------------------------------------- +// Init / free +// --------------------------------------------------------------------------- + +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init(const int * devices, size_t n_devices) { + + if (n_devices != 2) { + GGML_LOG_DEBUG("%s: internal AllReduce only supports n_devices=2 (got %zu); " + "falling back\n", __func__, n_devices); + return nullptr; + } + + // The chunked kernel uses __nanosleep, which is sm70+ (Volta+). + for (size_t i = 0; i < n_devices; ++i) { + const int cc = ggml_cuda_info().devices[devices[i]].cc; + if (cc < GGML_CUDA_CC_VOLTA) { + GGML_LOG_DEBUG("%s: internal AllReduce requires compute capability >= %d " + "(device %d has cc=%d); falling back\n", + __func__, GGML_CUDA_CC_VOLTA, devices[i], cc); + return nullptr; + } + } + + auto * p = new ggml_cuda_ar_pipeline{}; + p->n_devices = n_devices; + p->copy_bytes = GGML_CUDA_AR_COPY_MAX_BYTES; + p->copy_threshold = ggml_cuda_ar_env_u64("GGML_CUDA_AR_COPY_THRESHOLD", GGML_CUDA_AR_COPY_THRESHOLD_DEFAULT); + // 0 = use the per-call heuristic (default). Non-zero env value forces a + // fixed chunk size for diagnostics, with a floor at COPY_CHUNK_BYTES_MIN. + p->copy_chunk_bytes = ggml_cuda_ar_env_u64("GGML_CUDA_AR_COPY_CHUNK_BYTES", 0); + if (p->copy_chunk_bytes > 0 && p->copy_chunk_bytes < GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN) { + GGML_LOG_WARN("%s: GGML_CUDA_AR_COPY_CHUNK_BYTES=%zu below minimum %zu; clamping\n", + __func__, p->copy_chunk_bytes, GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN); + p->copy_chunk_bytes = GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN; + } + // Default 1: BF16 round-trip is always on for F32 inputs (any non-zero + // ne). Set GGML_CUDA_AR_BF16_THRESHOLD=0 to disable, or to a larger + // byte threshold to opt out for small tensors. + p->bf16_threshold = ggml_cuda_ar_env_u64("GGML_CUDA_AR_BF16_THRESHOLD", 1); + for (size_t i = 0; i < n_devices; ++i) { + p->devices[i] = devices[i]; + } + + // Per-device streams and event pools. + for (size_t i = 0; i < n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + + cudaStream_t stream = nullptr; + if (cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaStreamCreateWithFlags failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + p->streams[i] = stream; + + for (int s = 0; s < GGML_CUDA_AR_POOL_SIZE; ++s) { + bool ok = + cudaEventCreateWithFlags(&p->ev_pool[i][s].app, cudaEventDisableTiming) == cudaSuccess && + cudaEventCreateWithFlags(&p->ev_pool[i][s].h2d, cudaEventDisableTiming) == cudaSuccess && + cudaEventCreateWithFlags(&p->ev_pool[i][s].ker, cudaEventDisableTiming) == cudaSuccess; + for (int c = 0; ok && c < GGML_CUDA_AR_COPY_MAX_CHUNKS; ++c) { + ok = cudaEventCreateWithFlags(&p->ev_pool[i][s].cpy[c], cudaEventDisableTiming) == cudaSuccess; + } + if (!ok) { + GGML_LOG_ERROR("%s: cudaEventCreate failed for device %d slot %d\n", + __func__, p->devices[i], s); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + if (cudaEventCreateWithFlags(&p->host_large_read_done[i], cudaEventDisableTiming) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaEventCreate for host_large_read_done failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + if (cudaEventCreateWithFlags(&p->dev_tmp_kernel_done[i], cudaEventDisableTiming) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaEventCreate for dev_tmp_kernel_done failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + // Arrival ring: cache-line padded so each GPU's int is on its own line. + const size_t arrival_bytes = + (size_t)GGML_CUDA_AR_POOL_SIZE * n_devices * + GGML_CUDA_AR_KERNEL_BLOCKS * GGML_CUDA_AR_ARRIVAL_STRIDE; + if (p->arrival.alloc(arrival_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for arrival ring failed (%zu bytes)\n", + __func__, arrival_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + ggml_cuda_set_device(p->devices[0]); + if (cudaMemset(p->arrival.dev, 0, arrival_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaMemset for arrival ring failed (%zu bytes)\n", + __func__, arrival_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + + // Per-device pinned staging buffers -- POOL_SIZE-deep ring so the chunked- + // kernel can write the next slot's data while the peer is still reading + // the previous slot's. Indexed by (slot * buf_bytes) at the call site. + p->buf_bytes = GGML_CUDA_AR_MAX_BYTES; + const size_t host_buf_total = (size_t) GGML_CUDA_AR_POOL_SIZE * p->buf_bytes; + for (size_t i = 0; i < n_devices; ++i) { + if (p->host_buf[i].alloc(host_buf_total) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for staging failed (%zu bytes)\n", + __func__, host_buf_total); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + // Copy-engine path: pinned host staging + device scratch, sized for the + // largest tensor we accept on this path (GGML_CUDA_AR_COPY_MAX_BYTES). + // dev_tmp is single-buffered; cross-AR safety is enforced by an explicit + // cross-stream wait in copy_impl on the prior AR's add_kernel-done event. + for (size_t i = 0; i < n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + if (p->host_large[i].alloc(p->copy_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for large staging failed (%zu bytes)\n", + __func__, p->copy_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + if (cudaMalloc(reinterpret_cast(&p->dev_tmp[i]), p->copy_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaMalloc for copy scratch failed (%zu bytes) on device %d\n", + __func__, p->copy_bytes, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + GGML_LOG_INFO("%s: initialized AllReduce pipeline: %zu GPUs, " + "%zu KB chunked kernel staging + %zu MB copy-engine staging per GPU\n", + __func__, n_devices, p->buf_bytes >> 10, p->copy_bytes >> 20); + + return p; +} + +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline * p) { + if (!p) { + return; + } + + // Drain all in-flight kernels before tearing down resources. + for (int i = 0; i < p->n_devices; ++i) { + if (p->streams[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaStreamSynchronize(p->streams[i]); + } + } + + for (int i = 0; i < p->n_devices; ++i) { + p->host_buf[i].free(); + p->host_large[i].free(); + if (p->dev_tmp[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaFree(p->dev_tmp[i]); + } + ggml_cuda_set_device(p->devices[i]); + for (int s = 0; s < GGML_CUDA_AR_POOL_SIZE; ++s) { + if (p->ev_pool[i][s].app) { cudaEventDestroy(p->ev_pool[i][s].app); } + for (int c = 0; c < GGML_CUDA_AR_COPY_MAX_CHUNKS; ++c) { + if (p->ev_pool[i][s].cpy[c]) { cudaEventDestroy(p->ev_pool[i][s].cpy[c]); } + } + if (p->ev_pool[i][s].h2d) { cudaEventDestroy(p->ev_pool[i][s].h2d); } + if (p->ev_pool[i][s].ker) { cudaEventDestroy(p->ev_pool[i][s].ker); } + } + if (p->host_large_read_done[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaEventDestroy(p->host_large_read_done[i]); + } + if (p->dev_tmp_kernel_done[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaEventDestroy(p->dev_tmp_kernel_done[i]); + } + if (p->streams[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaStreamDestroy(p->streams[i]); + } + } + p->arrival.free(); + delete p; +} + +// --------------------------------------------------------------------------- +// Dispatch +// --------------------------------------------------------------------------- + +// Asymmetric copy_impl: data sent over PCIe in T_src precision (one element of +// nbytes per ne element); accumulated locally into a T_dst buffer. When +// T_src == T_dst this is the original homogeneous reduction. When they differ +// (e.g. BF16 wire / F32 accumulator) the add kernel rounds dst through T_src +// for bit-equivalence between GPUs and we skip the otherwise-needed +// post-conversion entirely. +template +static bool ggml_cuda_ar_allreduce_copy_impl( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + T_src * const src_buf[GGML_CUDA_MAX_DEVICES], + T_dst * const dst_buf[GGML_CUDA_MAX_DEVICES], + const bool compute[GGML_CUDA_MAX_DEVICES], + int64_t ne, + size_t nbytes) { + GGML_ASSERT(p->n_devices == 2); + GGML_ASSERT(nbytes <= p->copy_bytes); + GGML_ASSERT(ne <= std::numeric_limits::max()); + + const size_t chunk_bytes = ggml_cuda_ar_chunk_bytes(p, nbytes); + GGML_ASSERT(chunk_bytes > 0); + + const int slot = ggml_cuda_ar_acquire_slot(p).slot; + const size_t copy_chunks = (nbytes + chunk_bytes - 1) / chunk_bytes; + GGML_ASSERT(copy_chunks <= GGML_CUDA_AR_COPY_MAX_CHUNKS); + + ggml_backend_cuda_context * cuda_ctx[2] = {}; + + // Stage 1: both GPUs copy their local contribution to pinned host memory. + for (int i = 0; i < 2; ++i) { + ggml_cuda_set_device(p->devices[i]); + cuda_ctx[i] = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx[i]->device == p->devices[i]); + + ggml_cuda_ar_wait_for_compute(p, cuda_ctx[i], i, slot); + + // Wait for peer's H2D from our host_large[i] (recorded in the + // previous AR's stage 2) to complete before we overwrite host_large[i]. + // host_large_read_done[peer] = peer finished reading host_large[i]. + // No-op on the first AR -- no prior record exists. + if (p->host_large_read_done_valid) { + const int peer = 1 - i; + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->host_large_read_done[peer])); + } + + if (!compute[i]) { + CUDA_CHECK(cudaMemsetAsync(src_buf[i], 0, nbytes, p->streams[i])); + } + + for (size_t c = 0; c < copy_chunks; ++c) { + const size_t offset = c * chunk_bytes; + const size_t this_bytes = (nbytes - offset) < chunk_bytes ? + (nbytes - offset) : chunk_bytes; + + CUDA_CHECK(cudaMemcpyAsync( + p->host_large[i].host + offset, reinterpret_cast(src_buf[i]) + offset, this_bytes, + cudaMemcpyDeviceToHost, p->streams[i])); + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].cpy[c], p->streams[i])); + } + } + + // Stage 2: each GPU waits for each peer D2H chunk, pulls that chunk back to + // local device scratch (dev_tmp), then performs one device-local add over + // the assembled peer tensor. The H2Ds run on the AR stream (copy engine) + // and the add_kernel runs on the caller's compute stream, so the AR stream + // stays pure-copy and avoids an in-stream copy->compute engine switch every + // AR. dev_tmp is single-buffered: the AR stream waits cross-stream on the + // prior AR's add_kernel-done event before overwriting it. + for (int i = 0; i < 2; ++i) { + const int peer = 1 - i; + ggml_cuda_set_device(p->devices[i]); + + // Wait for the previous AR's add_kernel (on the compute stream) to + // finish reading dev_tmp before our H2D overwrites it. No-op on the + // first copy_impl call. + if (p->dev_tmp_kernel_done_valid) { + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->dev_tmp_kernel_done[i])); + } + + for (size_t c = 0; c < copy_chunks; ++c) { + const size_t offset = c * chunk_bytes; + const size_t this_bytes = (nbytes - offset) < chunk_bytes ? + (nbytes - offset) : chunk_bytes; + + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->ev_pool[peer][slot].cpy[c])); + CUDA_CHECK(cudaMemcpyAsync( + p->dev_tmp[i] + offset, p->host_large[peer].host + offset, this_bytes, + cudaMemcpyHostToDevice, p->streams[i])); + } + + // Mark our reads of host_large[peer] complete so peer's next AR can + // safely overwrite it. + CUDA_CHECK(cudaEventRecord(p->host_large_read_done[i], p->streams[i])); + + // Hand off from AR stream (copy engine) to compute stream: compute + // stream waits for all H2Ds to finish, then runs the add_kernel. + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].h2d, p->streams[i])); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx[i]->stream(), p->ev_pool[i][slot].h2d)); + + const int block_size = 256; + int n_blocks = (int) ((ne + block_size - 1) / block_size); + if (n_blocks > 1024) { + n_blocks = 1024; + } + ggml_cuda_ar_add_kernel<<stream()>>>( + dst_buf[i], + reinterpret_cast(p->dev_tmp[i]), + (int) ne); + CUDA_CHECK(cudaGetLastError()); + + // Record dev_tmp-released on the compute stream so the next copy_impl + // can wait for the kernel to finish before overwriting dev_tmp. Also + // record AR-done as ev.ker for acquire_slot's pool-wraparound sync. + CUDA_CHECK(cudaEventRecord(p->dev_tmp_kernel_done[i], cuda_ctx[i]->stream())); + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].ker, cuda_ctx[i]->stream())); + } + p->host_large_read_done_valid = true; + p->dev_tmp_kernel_done_valid = true; + + return true; +} + +// Outer-level chunker: copy_impl handles up to copy_bytes per call (limited by +// the host_large / dev_tmp allocation size). When the full AR exceeds that, +// slice the tensor into copy_bytes-sized pieces and call copy_impl repeatedly. +// Each slice goes through its own stage 1 -> stage 2 cycle and acquires its own +// slot, so cross-AR fences and pool wraparound work the same way as for any +// other sequence of small ARs. +template +static bool ggml_cuda_ar_allreduce_copy_outer( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + T_src * const src_buf[GGML_CUDA_MAX_DEVICES], + T_dst * const dst_buf[GGML_CUDA_MAX_DEVICES], + const bool compute[GGML_CUDA_MAX_DEVICES], + int64_t ne) { + const int64_t outer_max_elems = (int64_t) (p->copy_bytes / sizeof(T_src)); + GGML_ASSERT(outer_max_elems > 0); + + bool ok = true; + for (int64_t outer_start = 0; outer_start < ne && ok; outer_start += outer_max_elems) { + const int64_t outer_ne = std::min(outer_max_elems, ne - outer_start); + const size_t outer_nbytes = (size_t) outer_ne * sizeof(T_src); + + T_src * src[GGML_CUDA_MAX_DEVICES] = {}; + T_dst * dst[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < p->n_devices; ++i) { + src[i] = src_buf[i] + outer_start; + dst[i] = dst_buf[i] + outer_start; + } + ok = ggml_cuda_ar_allreduce_copy_impl( + p, backends, src, dst, compute, outer_ne, outer_nbytes); + } + return ok; +} + +bool ggml_cuda_ar_allreduce( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + ggml_tensor ** tensors) { + GGML_ASSERT(p != nullptr); + + const int n = p->n_devices; + GGML_ASSERT(n == 2); + + const ggml_type input_type = tensors[0]->type; + GGML_ASSERT(input_type == GGML_TYPE_F32 || input_type == GGML_TYPE_F16 || input_type == GGML_TYPE_BF16); + + const int64_t ne = ggml_nelements(tensors[0]); + GGML_ASSERT(ne > 0); + + const size_t input_nbytes = ggml_nbytes(tensors[0]); + + // BF16 round-trip: F32 inputs >= bf16_threshold are converted to BF16 for + // the reduction (chunked or copy-engine), halving on-wire bytes. Matches + // NCCL's behaviour. The pre-conversion zeroes inactive shards so the + // inner paths see them as already-prepared compute tensors. + const bool use_bf16 = + input_type == GGML_TYPE_F32 && + p->bf16_threshold > 0 && + input_nbytes >= p->bf16_threshold; + + const ggml_type kernel_type = use_bf16 ? GGML_TYPE_BF16 : input_type; + const size_t type_size = ggml_type_size(kernel_type); + GGML_ASSERT(p->buf_bytes >= type_size); + const size_t nbytes = (size_t) ne * type_size; + + bool compute_flag[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + compute_flag[i] = (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) != 0; + } + + // Decide between copy-engine and chunked kernel paths based on the working + // type's actual byte count. No upper bound: copy_outer slices reductions + // larger than copy_bytes into copy_bytes-sized pieces. + const bool use_copy_engine = + p->copy_threshold > 0 && + nbytes >= p->copy_threshold; + + // BF16 inactive-shard zeroing: when use_bf16 is on, the combined kernel + // (chunked kernel path) and the combined add kernel (copy_engine path) + // both accumulate into the F32 tensor data directly, so an inactive + // shard's accumulator must start at zero. + if (use_bf16) { + for (int i = 0; i < n; ++i) { + if (!compute_flag[i]) { + auto * cuda_ctx = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + ggml_cuda_set_device(p->devices[i]); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, (size_t) ne * sizeof(float), cuda_ctx->stream())); + } + } + } + + // Pre-convert F32 -> BF16 into bf16_tmp ONLY for the copy_engine + use_bf16 + // path; the chunked kernel path's combined kernel does the conversion + // inline as it writes to host_buf. + ggml_cuda_pool_alloc bf16_tmp[GGML_CUDA_MAX_DEVICES]; + void * copy_src_ptr[GGML_CUDA_MAX_DEVICES] = {}; + + if (use_copy_engine && use_bf16) { + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32); + for (int i = 0; i < n; ++i) { + auto * cuda_ctx = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + bf16_tmp[i].pool = &cuda_ctx->pool(); + bf16_tmp[i].alloc(ne); + ggml_cuda_set_device(p->devices[i]); + if (compute_flag[i]) { + to_bf16(tensors[i]->data, bf16_tmp[i].get(), ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } else { + CUDA_CHECK(cudaMemsetAsync(bf16_tmp[i].get(), 0, nbytes, cuda_ctx->stream())); + } + copy_src_ptr[i] = bf16_tmp[i].get(); + } + } + + bool ok = true; + if (use_copy_engine) { + // After up-front BF16 conversion, the tmp buffers already hold the + // (possibly zeroed-for-inactive) data, so the inner path can treat + // every shard as compute. + bool inner_compute[GGML_CUDA_MAX_DEVICES]; + for (int i = 0; i < n; ++i) { + inner_compute[i] = use_bf16 ? true : compute_flag[i]; + } + + // Dispatch into copy_impl with explicit src/dst types. When use_bf16 + // is on, the wire type is BF16 (src = bf16_tmp) and the accumulator + // is F32 (dst = tensors[i]->data); the combined add kernel rounds dst + // through BF16 for bit-equivalence and writes F32 directly, so no + // post-conversion is needed. Otherwise src == dst (same native type). + if (use_bf16) { + GGML_ASSERT(kernel_type == GGML_TYPE_BF16); + nv_bfloat16 * src[GGML_CUDA_MAX_DEVICES] = {}; + float * dst[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + src[i] = static_cast(copy_src_ptr[i]); + dst[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, src, dst, inner_compute, ne); + } else { + switch (kernel_type) { + case GGML_TYPE_F32: { + float * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, buf, buf, inner_compute, ne); + break; + } + case GGML_TYPE_BF16: { + nv_bfloat16 * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, buf, buf, inner_compute, ne); + break; + } + case GGML_TYPE_F16: { + half * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, buf, buf, inner_compute, ne); + break; + } + default: + GGML_ASSERT(false); + } + } + } else { + // host_buf carries T_wire-typed data; max_chunk_elems is the count that + // fits in one host_buf at the wire size. + const size_t max_chunk_elems = p->buf_bytes / type_size; + const size_t input_type_size = ggml_type_size(input_type); + + // Chunked kernel path runs entirely on the caller's compute stream: + // since AR is a barrier here, same-stream ordering subsumes any + // cross-stream event handshake that the copy-engine path needs, and + // skips the cross-stream scheduling overhead that was hurting the + // small-tensor (tg) latency on the AR-stream variant. Only ev.ker is + // still recorded at end-of-AR for acquire_slot's pool-wraparound check. + for (int64_t chunk_start = 0; chunk_start < ne; chunk_start += (int64_t) max_chunk_elems) { + const size_t remaining_elems = (size_t) (ne - chunk_start); + const size_t chunk_elems = remaining_elems < max_chunk_elems ? remaining_elems : max_chunk_elems; + const size_t chunk_dst_bytes = chunk_elems * input_type_size; + + const auto [slot, token] = ggml_cuda_ar_acquire_slot(p); + const bool last_chunk = chunk_start + (int64_t) chunk_elems == ne; + + for (int i = 0; i < n; ++i) { + const int peer = 1 - i; // valid for n == 2 only + ggml_cuda_set_device(p->devices[i]); + auto * cuda_ctx = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + cudaStream_t stream = cuda_ctx->stream(); + + char * data = static_cast(tensors[i]->data) + chunk_start * (int64_t) input_type_size; + + // Match NCCL/meta-backend semantics: inactive shards contribute + // zeros. On the BF16 path the F32 tensor data was already + // zeroed up-front (above), so per-chunk zeroing isn't needed. + if (!compute_flag[i] && !use_bf16) { + CUDA_CHECK(cudaMemsetAsync(data, 0, chunk_dst_bytes, stream)); + } + +#define LAUNCH_AR_KERNEL(T_dst, T_wire) \ + ggml_cuda_ar_kernel<<>>( \ + reinterpret_cast(data), \ + reinterpret_cast(data), \ + reinterpret_cast(p->host_buf[i].dev + (size_t) slot * p->buf_bytes), \ + reinterpret_cast(p->host_buf[peer].dev + (size_t) slot * p->buf_bytes), \ + static_cast(chunk_elems), \ + ggml_cuda_ar_arrival_ptr(p, slot, i), \ + ggml_cuda_ar_arrival_ptr(p, slot, peer), \ + token) + + if (use_bf16) { + GGML_ASSERT(input_type == GGML_TYPE_F32); + LAUNCH_AR_KERNEL(float, nv_bfloat16); + } else { + switch (input_type) { + case GGML_TYPE_F32: LAUNCH_AR_KERNEL(float, float); break; + case GGML_TYPE_F16: LAUNCH_AR_KERNEL(half, half); break; + case GGML_TYPE_BF16: LAUNCH_AR_KERNEL(nv_bfloat16, nv_bfloat16); break; + default: GGML_ASSERT(false); + } + } + +#undef LAUNCH_AR_KERNEL + CUDA_CHECK(cudaGetLastError()); + + if (last_chunk) { + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].ker, stream)); + } + } + } + } + + return ok; +} + +#else // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) + +// HIP and MUSA lack the host-mapped pinned-memory APIs (cudaHostAllocPortable +// / cudaHostAllocMapped / cudaHostGetDevicePointer) and __nanosleep that this +// implementation relies on, so the internal AllReduce is a CUDA-only feature. +// The dispatcher in ggml-cuda.cu treats a nullptr pipeline as "init failed" +// and silently falls back to the meta backend's generic AllReduce. +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init(const int *, size_t) { + return nullptr; +} +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline *) { +} +bool ggml_cuda_ar_allreduce(ggml_cuda_ar_pipeline *, ggml_backend_t *, ggml_tensor **) { + return false; +} + +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) diff --git a/ggml/src/ggml-cuda/allreduce.cuh b/ggml/src/ggml-cuda/allreduce.cuh new file mode 100644 index 00000000000..0f2c9518d5d --- /dev/null +++ b/ggml/src/ggml-cuda/allreduce.cuh @@ -0,0 +1,29 @@ +#pragma once + +#include "common.cuh" +#include "ggml-backend-impl.h" + +#include + +// Opaque pipeline context -- owns all pinned buffers, streams, and events. +struct ggml_cuda_ar_pipeline; + +// Allocate a pipeline for n_devices GPUs. +// devices[] holds the CUDA device IDs in rank order. +// Returns nullptr on allocation failure. +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init( + const int * devices, size_t n_devices); + +// Release all resources owned by the pipeline. +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline * pipeline); + +// Execute an in-place AllReduce (sum) across tensors[0..n_devices-1]. +// tensors[i] must live on the device managed by backends[i] and be +// contiguous F32, F16, or BF16. +// Preconditions are checked by the CUDA comm dispatcher before calling this. +// Returns true once the reduction work has been enqueued successfully. +bool ggml_cuda_ar_allreduce( + ggml_cuda_ar_pipeline * pipeline, + ggml_backend_t * backends, + ggml_tensor ** tensors); + diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4df1b930882..b92a208705d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2,6 +2,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-cuda/allreduce.cuh" #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" #include "ggml-cuda/add-id.cuh" @@ -86,6 +87,9 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +#define GGML_LOG_WARN_ONCE(str) \ + { static std::once_flag warn_flag; std::call_once(warn_flag, []() { GGML_LOG_WARN(str); }); } + [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails @@ -1139,70 +1143,46 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; -#ifdef GGML_USE_NCCL +// Communication context for multi-GPU AllReduce during tensor parallelism. +// +// Created once per meta backend instance. Resources for the selected mode +// (NCCL communicators or the internal AllReduce pipeline) are initialised +// eagerly during comm_init so any init failure surfaces at startup rather +// than mid-run. struct ggml_backend_cuda_comm_context { + using try_allreduce_fn = bool(*)(ggml_backend_cuda_comm_context *, struct ggml_tensor **); + std::vector backends; - std::vector comms; + std::vector dev_ids; - ~ggml_backend_cuda_comm_context() { - for (ncclComm_t comm : comms) { - NCCL_CHECK(ncclCommDestroy(comm)); - } - } -}; -#endif // GGML_USE_NCCL + // Set by the init chain (comm_init_{nccl, internal, none}) to one of + // try_allreduce_{nccl, internal, butterfly}. nccl needs `comms`, + // internal needs `ar_pipeline`, butterfly needs nothing. Per-call + // failures return false; the meta backend's generic implementation then + // handles that call. + try_allreduce_fn try_allreduce = nullptr; + + ggml_cuda_ar_pipeline * ar_pipeline = nullptr; -static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { #ifdef GGML_USE_NCCL - if (comm_ctx_v == nullptr) { - return; - } - ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; - delete comm_ctx; -#else - GGML_UNUSED(comm_ctx_v); + std::vector comms; #endif // GGML_USE_NCCL -} -static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { + ~ggml_backend_cuda_comm_context() { #ifdef GGML_USE_NCCL - for (size_t i = 0; i < n_backends; i++) { - if (!ggml_backend_is_cuda(backends[i])) { - return nullptr; + for (ncclComm_t comm : comms) { + NCCL_CHECK(ncclCommDestroy(comm)); } - } - ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context; - std::vector dev_ids; - ret->backends.reserve(n_backends); - dev_ids.reserve(n_backends); - for (size_t i = 0; i < n_backends; i++) { - ret->backends.push_back(backends[i]); - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; - dev_ids.push_back(cuda_ctx->device); - } - - ret->comms.resize(n_backends); - NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data())); - return ret; -#else - // If NCCL is installed it is used by default for optimal performance. - // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. - // RCCL is disabled by default, users are explicitly opting in. - // Therefore print no warning for RCCL. -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - static bool warning_printed = false; - if (!warning_printed) { - GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); - warning_printed = true; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - GGML_UNUSED_VARS(backends, n_backends); - return nullptr; #endif // GGML_USE_NCCL -} + ggml_cuda_ar_pipeline_free(ar_pipeline); + } +}; -static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { #ifdef GGML_USE_NCCL +// AllReduce via NCCL. Reduces as FP32 for small tensors and BF16 for large +// tensors (bandwidth-bound), then converts back to FP32. +static bool ggml_backend_cuda_comm_allreduce_nccl( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { const int64_t ne = ggml_nelements(tensors[0]); // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 // This then causes a crash in this function @@ -1210,8 +1190,6 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg return true; } - GGML_ASSERT(comm_ctx_v != nullptr); - ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; const size_t n_backends = comm_ctx->backends.size(); for (size_t i = 0; i < n_backends; ++i) { @@ -1236,7 +1214,6 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); } NCCL_CHECK(ncclGroupEnd()); - return true; } @@ -1275,10 +1252,184 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg } return true; -#else - GGML_UNUSED_VARS(comm_ctx_v, tensors); +} +#endif // GGML_USE_NCCL + +// Run the internal AR pipeline. Returns false on unsupported / failed input +// -- the caller decides whether to abort (env-forced) or fall back silently. +static bool ggml_backend_cuda_comm_allreduce_internal( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + GGML_ASSERT(comm_ctx->ar_pipeline != nullptr); + + const size_t n_backends = comm_ctx->backends.size(); + GGML_ASSERT(n_backends == 2); + GGML_ASSERT(tensors[0] != nullptr); + + const int64_t ne = ggml_nelements(tensors[0]); + const ggml_type type = tensors[0]->type; + + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16 && type != GGML_TYPE_BF16) { + GGML_LOG_DEBUG("%s: internal unsupported: type=%d\n", __func__, (int) type); + return false; + } + + if (ne == 0) { + return true; + } + + for (size_t i = 0; i < n_backends; ++i) { + if (tensors[i] == nullptr) { + GGML_LOG_ERROR("%s: internal failed: tensor[%zu] is null\n", __func__, i); + return false; + } + if (ggml_nelements(tensors[i]) != ne || tensors[i]->type != type) { + GGML_LOG_ERROR("%s: internal failed: tensor[%zu] ne=%" PRId64 " type=%d expected ne=%" PRId64 " type=%d\n", + __func__, i, ggml_nelements(tensors[i]), (int) tensors[i]->type, ne, (int) type); + return false; + } + if (!ggml_is_contiguously_allocated(tensors[i])) { + GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] is not contiguously allocated: ne=%" PRId64 " nbytes=%zu packed=%zu type=%d\n", + __func__, i, ne, ggml_nbytes(tensors[i]), + (size_t) ne * ggml_type_size(type) / ggml_blck_size(type), (int) type); + return false; + } + if (((uintptr_t) tensors[i]->data & 0xF) != 0) { + GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] data pointer is not 16-byte aligned: %p type=%d ne=%" PRId64 "\n", + __func__, i, tensors[i]->data, (int) type, ne); + return false; + } + GGML_ASSERT((ggml_nbytes(tensors[i]) & 0xF) == 0); + } + + return ggml_cuda_ar_allreduce(comm_ctx->ar_pipeline, comm_ctx->backends.data(), tensors); +} + +// --------------------------------------------------------------------------- +// Per-call dispatch -- three variants, one per backend. Each is set as +// comm_ctx->try_allreduce by the matching init step. Per-call failure +// returns false; the meta backend's generic implementation handles that call. +// --------------------------------------------------------------------------- + +#ifdef GGML_USE_NCCL +static bool ggml_backend_cuda_comm_try_allreduce_nccl( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + return ggml_backend_cuda_comm_allreduce_nccl(comm_ctx, tensors); +} +#endif // GGML_USE_NCCL + +static bool ggml_backend_cuda_comm_try_allreduce_internal( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + return ggml_backend_cuda_comm_allreduce_internal(comm_ctx, tensors); +} + +static bool ggml_backend_cuda_comm_try_allreduce_butterfly( + ggml_backend_cuda_comm_context *, struct ggml_tensor **) { return false; +} + +static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { + if (comm_ctx_v == nullptr) { + return; + } + delete static_cast(comm_ctx_v); +} + +// --------------------------------------------------------------------------- +// Init -- chained nccl -> internal -> none. Each step tries to bring up its +// resource; on failure it warns and recurses into the next step. +// --------------------------------------------------------------------------- +static void ggml_backend_cuda_comm_init_none(ggml_backend_cuda_comm_context * ret) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_butterfly; +} + +static void ggml_backend_cuda_comm_init_internal(ggml_backend_cuda_comm_context * ret) { + ret->ar_pipeline = ggml_cuda_ar_pipeline_init(ret->dev_ids.data(), ret->dev_ids.size()); + if (ret->ar_pipeline) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_internal; + return; + } + + // Clear sticky CUDA error from the failed init. + (void) cudaGetLastError(); + GGML_LOG_WARN("internal AllReduce init failed (n_devices != 2?); " + "falling back to meta-backend butterfly\n"); + ggml_backend_cuda_comm_init_none(ret); +} + +static void ggml_backend_cuda_comm_init_nccl(ggml_backend_cuda_comm_context * ret) { +#ifdef GGML_USE_NCCL + const size_t n = ret->dev_ids.size(); + ret->comms.resize(n); + ncclResult_t rc = ncclCommInitAll(ret->comms.data(), (int) n, ret->dev_ids.data()); + if (rc == ncclSuccess) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_nccl; + return; + } + + ret->comms.clear(); + GGML_LOG_WARN("NCCL init failed (%s); falling back to internal AllReduce\n", + ncclGetErrorString(rc)); +#else // GGML_USE_NCCL +#ifndef GGML_USE_HIP + GGML_LOG_WARN("NCCL not compiled in; falling back to internal AllReduce. " + "Recompile with -DGGML_CUDA_NCCL=ON for best multi-GPU performance.\n"); +#endif // !GGML_USE_HIP #endif // GGML_USE_NCCL + + ggml_backend_cuda_comm_init_internal(ret); +} + +// Top-level init. Picks one of the three init paths based on +// GGML_CUDA_ALLREDUCE (or the platform default) and lets the chain handle +// any fallback. Unrecognised env values warn and fall through to the +// platform default. +static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { + for (size_t i = 0; i < n_backends; i++) { + if (!ggml_backend_is_cuda(backends[i])) { + return nullptr; + } + } + + auto * ret = new ggml_backend_cuda_comm_context; + ret->backends.assign(backends, backends + n_backends); + ret->dev_ids.reserve(n_backends); + for (size_t i = 0; i < n_backends; i++) { + ret->dev_ids.push_back(static_cast(backends[i]->context)->device); + } + + const char * env = getenv("GGML_CUDA_ALLREDUCE"); + if (!env) { + // Platform default: Linux uses NCCL, otherwise (generally Windows) internal +#if defined(__linux__) + ggml_backend_cuda_comm_init_nccl(ret); +#else + ggml_backend_cuda_comm_init_internal(ret); +#endif // defined(__linux__) + } else { + std::string env_str(env); + if (env_str == "nccl") { + ggml_backend_cuda_comm_init_nccl(ret); + } else if (env_str == "internal") { + ggml_backend_cuda_comm_init_internal(ret); + } else if (env_str == "none") { + ggml_backend_cuda_comm_init_none(ret); + } else { + GGML_LOG_WARN("unknown GGML_CUDA_ALLREDUCE value: %s\n", env); + ggml_backend_cuda_comm_init_none(ret); + } + } + + return ret; +} + +// Top-level dispatch -- calls the function pointer chosen by comm_init. +// Returns false to let the meta-backend's butterfly run. +static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { + if (comm_ctx_v == nullptr) { + return false; + } + auto * comm_ctx = static_cast(comm_ctx_v); + return comm_ctx->try_allreduce(comm_ctx, tensors); } ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { From cf6e65bc594b4bc2648d2b3f8e98ffa42c01034e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 May 2026 16:57:19 +0300 Subject: [PATCH 041/289] ggml : bump version to 0.11.1 (ggml/1484) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8dd4d64063f..672b37dffc3 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 11) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 4730e765525b718133b5d63d664499ba33b7cd5a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 May 2026 17:27:59 +0300 Subject: [PATCH 042/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 812e721a8c5..15685a0718f 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -19eac6f0edaf285506eb6228d31bb9caeda9aba1 +628249b398293fc8d2fa81a449ae2920a02c6523 From 54ecc9dba43ccf99dcbcb6e1ae3396806e19bf4b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 May 2026 17:34:06 +0300 Subject: [PATCH 043/289] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 1 + examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-context.cpp | 347 +- examples/talk-llama/llama-context.h | 19 + examples/talk-llama/llama-graph.cpp | 7 +- examples/talk-llama/llama-hparams.h | 2 + examples/talk-llama/llama-io.cpp | 9 +- examples/talk-llama/llama-io.h | 6 +- examples/talk-llama/llama-kv-cache.cpp | 54 +- .../talk-llama/llama-memory-recurrent.cpp | 42 +- examples/talk-llama/llama-model-saver.cpp | 1 + examples/talk-llama/llama-model.cpp | 9207 ++--------------- examples/talk-llama/llama-model.h | 89 +- examples/talk-llama/llama-quant.cpp | 23 +- examples/talk-llama/llama-vocab.cpp | 13 + examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/llama.cpp | 246 +- examples/talk-llama/llama.h | 3 + examples/talk-llama/models/afmoe.cpp | 108 +- examples/talk-llama/models/apertus.cpp | 58 +- examples/talk-llama/models/arcee.cpp | 47 +- examples/talk-llama/models/arctic.cpp | 55 +- examples/talk-llama/models/arwkv7.cpp | 118 +- examples/talk-llama/models/baichuan.cpp | 45 +- examples/talk-llama/models/bailingmoe.cpp | 61 +- examples/talk-llama/models/bailingmoe2.cpp | 96 +- examples/talk-llama/models/bert.cpp | 79 +- examples/talk-llama/models/bitnet.cpp | 49 +- examples/talk-llama/models/bloom.cpp | 64 +- examples/talk-llama/models/chameleon.cpp | 52 +- examples/talk-llama/models/chatglm.cpp | 55 +- examples/talk-llama/models/codeshell.cpp | 51 +- examples/talk-llama/models/cogvlm.cpp | 51 +- .../models/{cohere2-iswa.cpp => cohere2.cpp} | 49 +- examples/talk-llama/models/command-r.cpp | 42 +- examples/talk-llama/models/dbrx.cpp | 46 +- examples/talk-llama/models/deci.cpp | 78 +- examples/talk-llama/models/deepseek.cpp | 73 +- examples/talk-llama/models/deepseek2.cpp | 145 +- examples/talk-llama/models/deepseek2ocr.cpp | 82 + examples/talk-llama/models/dots1.cpp | 72 +- examples/talk-llama/models/dream.cpp | 50 +- examples/talk-llama/models/ernie4-5-moe.cpp | 6 +- examples/talk-llama/models/ernie4-5.cpp | 75 +- examples/talk-llama/models/eurobert.cpp | 37 +- examples/talk-llama/models/exaone-moe.cpp | 113 +- examples/talk-llama/models/exaone.cpp | 45 +- examples/talk-llama/models/exaone4.cpp | 70 +- examples/talk-llama/models/falcon-h1.cpp | 111 +- examples/talk-llama/models/falcon.cpp | 49 +- .../talk-llama/models/gemma-embedding.cpp | 74 +- examples/talk-llama/models/gemma.cpp | 40 +- .../models/{gemma2-iswa.cpp => gemma2.cpp} | 61 +- examples/talk-llama/models/gemma3.cpp | 86 +- .../models/{gemma3n-iswa.cpp => gemma3n.cpp} | 99 +- .../models/{gemma4-iswa.cpp => gemma4.cpp} | 149 +- examples/talk-llama/models/glm-dsa.cpp | 155 + examples/talk-llama/models/glm4-moe.cpp | 135 +- examples/talk-llama/models/glm4.cpp | 74 +- examples/talk-llama/models/gpt2.cpp | 56 +- examples/talk-llama/models/gptneox.cpp | 85 +- examples/talk-llama/models/granite-hybrid.cpp | 137 +- examples/talk-llama/models/granite-moe.cpp | 89 + examples/talk-llama/models/granite.cpp | 93 +- examples/talk-llama/models/grok.cpp | 85 +- examples/talk-llama/models/grovemoe.cpp | 66 +- examples/talk-llama/models/hunyuan-dense.cpp | 132 +- examples/talk-llama/models/hunyuan-moe.cpp | 55 +- examples/talk-llama/models/hunyuan-vl.cpp | 189 + examples/talk-llama/models/internlm2.cpp | 39 +- examples/talk-llama/models/jais.cpp | 54 +- examples/talk-llama/models/jais2.cpp | 57 +- examples/talk-llama/models/jamba.cpp | 107 +- examples/talk-llama/models/jina-bert-v2.cpp | 66 + examples/talk-llama/models/jina-bert-v3.cpp | 69 + examples/talk-llama/models/kimi-linear.cpp | 172 +- examples/talk-llama/models/lfm2.cpp | 92 +- examples/talk-llama/models/lfm2moe.cpp | 85 + examples/talk-llama/models/llada-moe.cpp | 52 +- examples/talk-llama/models/llada.cpp | 68 +- examples/talk-llama/models/llama-embed.cpp | 6 + examples/talk-llama/models/llama.cpp | 101 +- examples/talk-llama/models/llama4.cpp | 108 +- examples/talk-llama/models/maincoder.cpp | 45 +- examples/talk-llama/models/mamba.cpp | 87 +- examples/talk-llama/models/mamba2.cpp | 87 + examples/talk-llama/models/mimo2-iswa.cpp | 129 - examples/talk-llama/models/mimo2.cpp | 240 + examples/talk-llama/models/minicpm.cpp | 89 + examples/talk-llama/models/minicpm3.cpp | 62 +- examples/talk-llama/models/minimax-m2.cpp | 46 +- examples/talk-llama/models/mistral3.cpp | 92 +- examples/talk-llama/models/mistral4.cpp | 6 + examples/talk-llama/models/models.h | 1866 +++- examples/talk-llama/models/modern-bert.cpp | 65 +- examples/talk-llama/models/mpt.cpp | 66 +- examples/talk-llama/models/nemotron-h-moe.cpp | 6 + examples/talk-llama/models/nemotron-h.cpp | 127 +- examples/talk-llama/models/nemotron.cpp | 48 +- examples/talk-llama/models/neo-bert.cpp | 42 +- examples/talk-llama/models/nomic-bert-moe.cpp | 72 + examples/talk-llama/models/nomic-bert.cpp | 72 + examples/talk-llama/models/olmo.cpp | 42 +- examples/talk-llama/models/olmo2.cpp | 67 +- examples/talk-llama/models/olmoe.cpp | 51 +- .../{openai-moe-iswa.cpp => openai-moe.cpp} | 63 +- examples/talk-llama/models/openelm.cpp | 49 +- examples/talk-llama/models/orion.cpp | 42 +- examples/talk-llama/models/paddleocr.cpp | 6 +- .../{pangu-embedded.cpp => pangu-embed.cpp} | 56 +- examples/talk-llama/models/phi2.cpp | 46 +- examples/talk-llama/models/phi3.cpp | 70 +- examples/talk-llama/models/phimoe.cpp | 55 + examples/talk-llama/models/plamo.cpp | 38 +- examples/talk-llama/models/plamo2.cpp | 109 +- examples/talk-llama/models/plamo3.cpp | 73 +- examples/talk-llama/models/plm.cpp | 46 +- examples/talk-llama/models/qwen.cpp | 42 +- examples/talk-llama/models/qwen2.cpp | 51 +- examples/talk-llama/models/qwen2moe.cpp | 63 +- examples/talk-llama/models/qwen2vl.cpp | 41 +- examples/talk-llama/models/qwen3.cpp | 51 +- examples/talk-llama/models/qwen35.cpp | 102 +- examples/talk-llama/models/qwen35moe.cpp | 115 +- examples/talk-llama/models/qwen3moe.cpp | 61 +- examples/talk-llama/models/qwen3next.cpp | 119 +- examples/talk-llama/models/qwen3vl.cpp | 52 +- .../{qwen3vl-moe.cpp => qwen3vlmoe.cpp} | 63 +- examples/talk-llama/models/refact.cpp | 77 +- examples/talk-llama/models/rnd1.cpp | 62 +- examples/talk-llama/models/rwkv6.cpp | 93 +- examples/talk-llama/models/rwkv6qwen2.cpp | 83 +- examples/talk-llama/models/rwkv7.cpp | 123 +- examples/talk-llama/models/seed-oss.cpp | 47 +- examples/talk-llama/models/smallthinker.cpp | 79 +- examples/talk-llama/models/smollm3.cpp | 45 +- examples/talk-llama/models/stablelm.cpp | 50 +- examples/talk-llama/models/starcoder.cpp | 58 +- examples/talk-llama/models/starcoder2.cpp | 57 +- .../models/{step35-iswa.cpp => step35.cpp} | 104 +- examples/talk-llama/models/t5.cpp | 122 +- examples/talk-llama/models/t5encoder.cpp | 43 +- .../talk-llama/models/wavtokenizer-dec.cpp | 117 +- examples/talk-llama/models/xverse.cpp | 39 +- 144 files changed, 12061 insertions(+), 9097 deletions(-) rename examples/talk-llama/models/{cohere2-iswa.cpp => cohere2.cpp} (60%) create mode 100644 examples/talk-llama/models/deepseek2ocr.cpp rename examples/talk-llama/models/{gemma2-iswa.cpp => gemma2.cpp} (53%) rename examples/talk-llama/models/{gemma3n-iswa.cpp => gemma3n.cpp} (76%) rename examples/talk-llama/models/{gemma4-iswa.cpp => gemma4.cpp} (62%) create mode 100644 examples/talk-llama/models/glm-dsa.cpp create mode 100644 examples/talk-llama/models/granite-moe.cpp create mode 100644 examples/talk-llama/models/hunyuan-vl.cpp create mode 100644 examples/talk-llama/models/jina-bert-v2.cpp create mode 100644 examples/talk-llama/models/jina-bert-v3.cpp create mode 100644 examples/talk-llama/models/lfm2moe.cpp create mode 100644 examples/talk-llama/models/llama-embed.cpp create mode 100644 examples/talk-llama/models/mamba2.cpp delete mode 100644 examples/talk-llama/models/mimo2-iswa.cpp create mode 100644 examples/talk-llama/models/mimo2.cpp create mode 100644 examples/talk-llama/models/minicpm.cpp create mode 100644 examples/talk-llama/models/mistral4.cpp create mode 100644 examples/talk-llama/models/nemotron-h-moe.cpp create mode 100644 examples/talk-llama/models/nomic-bert-moe.cpp create mode 100644 examples/talk-llama/models/nomic-bert.cpp rename examples/talk-llama/models/{openai-moe-iswa.cpp => openai-moe.cpp} (51%) rename examples/talk-llama/models/{pangu-embedded.cpp => pangu-embed.cpp} (53%) create mode 100644 examples/talk-llama/models/phimoe.cpp rename examples/talk-llama/models/{qwen3vl-moe.cpp => qwen3vlmoe.cpp} (57%) rename examples/talk-llama/models/{step35-iswa.cpp => step35.cpp} (52%) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 633a66fc665..59dde99e362 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -232,6 +232,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_VALUE_SCALE, "%s.attention.value_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 8f335f5c7b3..e37d548c98e 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -236,6 +236,7 @@ enum llm_kv { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_VALUE_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 8126249e143..71a59395eb2 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2230,13 +2230,17 @@ llm_graph_cb llama_context::graph_get_cb() const { class llama_io_write_dummy : public llama_io_write_i { public: - llama_io_write_dummy() = default; + llama_io_write_dummy(bool skip_tensors) : skip_tensors(skip_tensors) {} void write(const void * /* src */, size_t size) override { size_written += size; } - void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + void write_tensor(ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + if (skip_tensors) { + return; + } + size_written += size; } @@ -2245,14 +2249,23 @@ class llama_io_write_dummy : public llama_io_write_i { } private: + const bool skip_tensors; + size_t size_written = 0; }; -class llama_io_write_buffer : public llama_io_write_i { +class llama_io_write_host : public llama_io_write_i { public: - llama_io_write_buffer( + llama_io_write_host( uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + ~llama_io_write_host() { + // TODO: add backend support to batch tensor_get? or some other way to speed this up + for (const auto & winfo : winfos) { + ggml_backend_tensor_get(winfo.tensor, winfo.ptr, winfo.offset, winfo.size); + } + } + void write(const void * src, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); @@ -2263,11 +2276,14 @@ class llama_io_write_buffer : public llama_io_write_i { buf_size -= size; } - void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } - ggml_backend_tensor_get(tensor, ptr, offset, size); + + // save the write for later during destruction + winfos.push_back({tensor, ptr, size, offset}); + ptr += size; size_written += size; buf_size -= size; @@ -2281,25 +2297,48 @@ class llama_io_write_buffer : public llama_io_write_i { uint8_t * ptr; size_t buf_size = 0; size_t size_written = 0; + + struct write_info { + ggml_tensor * tensor; + uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector winfos; }; -class llama_io_read_buffer : public llama_io_read_i { +class llama_io_read_host : public llama_io_read_i { public: - llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + llama_io_read_host(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} - const uint8_t * read(size_t size) override { - const uint8_t * base_ptr = ptr; + ~llama_io_read_host() { + // flush the reads + for (const auto & rinfo : rinfos) { + ggml_backend_tensor_set(rinfo.tensor, rinfo.ptr, rinfo.offset, rinfo.size); + } + } + + void read(void * dst, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } + memcpy(dst, ptr, size); ptr += size; size_read += size; buf_size -= size; - return base_ptr; } - void read_to(void * dst, size_t size) override { - memcpy(dst, read(size), size); + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + + // save for later during destruction + rinfos.push_back({tensor, ptr, size, offset}); + + ptr += size; + size_read += size; + buf_size -= size; } size_t n_bytes() override { @@ -2310,6 +2349,14 @@ class llama_io_read_buffer : public llama_io_read_i { const uint8_t * ptr; size_t buf_size = 0; size_t size_read = 0; + + struct read_info { + ggml_tensor * tensor; + const uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector rinfos; }; class llama_io_write_file : public llama_io_write_i { @@ -2321,7 +2368,7 @@ class llama_io_write_file : public llama_io_write_i { size_written += size; } - void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { temp_buffer.resize(size); ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); write(temp_buffer.data(), temp_buffer.size()); @@ -2341,15 +2388,15 @@ class llama_io_read_file : public llama_io_read_i { public: llama_io_read_file(llama_file * f) : file(f) {} - void read_to(void * dst, size_t size) override { + void read(void * dst, size_t size) override { file->read_raw(dst, size); size_read += size; } - const uint8_t * read(size_t size) override { + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { temp_buffer.resize(size); - read_to(temp_buffer.data(), size); - return temp_buffer.data(); + read(temp_buffer.data(), size); + ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size); } size_t n_bytes() override { @@ -2362,8 +2409,212 @@ class llama_io_read_file : public llama_io_read_i { std::vector temp_buffer; }; +class llama_io_write_device : public llama_io_write_i { +public: + llama_io_write_device(uint8_t * p, size_t len, llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) { + } + + ~llama_io_write_device() { + llama_memory_buffers mbufs_new; + + for (const auto & winfo : winfos) { + auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer); + + mbufs_new[buft].n_tensors++; + mbufs_new[buft].total_size += winfo.size; + } + + for (auto & [buft, mbuf] : mbufs_new) { + ggml_init_params params = { + /*.mem_size =*/ 2*mbuf.n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + mbuf.ctx.reset(ggml_init(params)); + + mbuf.org.reserve(mbuf.n_tensors); + mbuf.cpy.reserve(mbuf.n_tensors); + } + + for (const auto & winfo : winfos) { + auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer); + + const int64_t n = winfo.size/ggml_element_size(winfo.tensor); + + auto & mbuf = mbufs_new[buft]; + + mbuf.org.push_back(ggml_view_1d (mbuf.ctx.get(), winfo.tensor, n, winfo.offset)); + mbuf.cpy.push_back(ggml_new_tensor_1d(mbuf.ctx.get(), winfo.tensor->type, n)); + } + + for (auto & [buft, mbuf] : mbufs_new) { + auto & mbuf_cur = mbufs[buft]; + + bool need_alloc = false; + + need_alloc = need_alloc || (!mbuf_cur.buf); + need_alloc = need_alloc || (mbuf_cur.org.size() != mbuf.org.size()); + need_alloc = need_alloc || (mbuf_cur.total_size != mbuf.total_size); + + if (!need_alloc) { + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + auto * org0 = mbuf_cur.org[i]; + auto * org1 = mbuf.org[i]; + + if (!ggml_are_same_shape(org0, org1)) { + need_alloc = true; + break; + } + + if (org0->view_src != org1->view_src || org0->view_offs != org1->view_offs) { + need_alloc = true; + break; + } + } + } + + if (need_alloc) { + mbuf_cur = std::move(mbuf); + + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_tensor_copy(mbuf_cur.org[i], mbuf_cur.cpy[i]); + } + } + } + + void write(const void * src, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(ptr, src, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + // save the write for later during destruction + winfos.push_back({tensor, ptr, size, offset}); + } + + size_t n_bytes() override { + return size_written; + } + +private: + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; + + struct write_info { + ggml_tensor * tensor; + uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector winfos; + + llama_memory_buffers & mbufs; +}; + +class llama_io_read_device : public llama_io_read_i { +public: + llama_io_read_device(const uint8_t * p, size_t len, const llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) { + } + + ~llama_io_read_device() { + llama_memory_buffers mbufs_new; + + for (const auto & rinfo : rinfos) { + auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer); + + mbufs_new[buft].n_tensors++; + mbufs_new[buft].total_size += rinfo.size; + } + + for (auto & [buft, mbuf] : mbufs_new) { + ggml_init_params params = { + /*.mem_size =*/ mbuf.n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + mbuf.ctx.reset(ggml_init(params)); + + mbuf.org.reserve(mbuf.n_tensors); + } + + for (const auto & rinfo : rinfos) { + auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer); + + const int64_t n = rinfo.size/ggml_element_size(rinfo.tensor); + + auto & mbuf = mbufs_new[buft]; + + mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); + + auto & view = mbuf.org.back(); + view->buffer = rinfo.tensor->buffer; + } + + for (auto & [buft, mbuf] : mbufs_new) { + const auto & mbuf_cur = mbufs.at(buft); + + if (!mbuf_cur.buf || mbuf_cur.n_tensors != mbuf.n_tensors || mbuf_cur.total_size != mbuf.total_size) { + GGML_ABORT("%s: memory buffer mismatch\n", __func__); + } + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf.org[i]); + } + } + + GGML_ASSERT(buf_size == 0); + } + + void read(void * dst, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(dst, ptr, size); + ptr += size; + size_read += size; + buf_size -= size; + } + + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + // save for later during destruction + rinfos.push_back({tensor, ptr, size, offset}); + } + + size_t n_bytes() override { + return size_read; + } + +private: + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; + + struct read_info { + ggml_tensor * tensor; + const uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector rinfos; + + const llama_memory_buffers & mbufs; +}; + size_t llama_context::state_get_size() { - llama_io_write_dummy io; + llama_io_write_dummy io(false); try { return state_write_data(io); } catch (const std::exception & err) { @@ -2373,7 +2624,7 @@ size_t llama_context::state_get_size() { } size_t llama_context::state_get_data(uint8_t * dst, size_t size) { - llama_io_write_buffer io(dst, size); + llama_io_write_host io(dst, size); try { return state_write_data(io); } catch (const std::exception & err) { @@ -2383,7 +2634,7 @@ size_t llama_context::state_get_data(uint8_t * dst, size_t size) { } size_t llama_context::state_set_data(const uint8_t * src, size_t size) { - llama_io_read_buffer io(src, size); + llama_io_read_host io(src, size); try { return state_read_data(io); } catch (const std::exception & err) { @@ -2392,9 +2643,14 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) { } } +static constexpr uint32_t io_magic = 0xaf143cd8; + size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { - llama_io_write_dummy io; + llama_io_write_dummy io(flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); try { + io.write(&io_magic, sizeof(io_magic)); + io.write(&seq_id, sizeof(seq_id)); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); @@ -2403,9 +2659,18 @@ size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_fl } size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { - llama_io_write_buffer io(dst, size); + std::unique_ptr io; + if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) { + io = std::make_unique(dst, size, mem_storage[seq_id]); + } else { + io = std::make_unique(dst, size); + } + try { - return state_seq_write_data(io, seq_id, flags); + io->write(&io_magic, sizeof(io_magic)); + io->write(&seq_id, sizeof(seq_id)); + + return state_seq_write_data(*io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; @@ -2413,9 +2678,38 @@ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, siz } size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { - llama_io_read_buffer io(src, size); + std::unique_ptr io; + if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) { + // create a temporary io to read the magic and the src seq_id + io = std::make_unique(src, size); + + uint32_t magic_read; + io->read(&magic_read, sizeof(magic_read)); + if (io_magic != magic_read) { + throw std::runtime_error("wrong sequence state magic"); + } + + llama_seq_id seq_id_read; + io->read(&seq_id_read, sizeof(seq_id_read)); + + GGML_ASSERT(mem_storage.find(seq_id_read) != mem_storage.end()); + + io = std::make_unique(src, size, mem_storage[seq_id_read]); + } else { + io = std::make_unique(src, size); + } + try { - return state_seq_read_data(io, seq_id, flags); + uint32_t magic_read; + io->read(&magic_read, sizeof(magic_read)); + if (io_magic != magic_read) { + throw std::runtime_error("wrong sequence state magic"); + } + + llama_seq_id seq_id_read; + io->read(&seq_id_read, sizeof(seq_id_read)); + + return state_seq_read_data(*io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -3406,7 +3700,6 @@ size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t s return ctx->state_seq_get_data(seq_id, dst, size, flags); } - size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 53c705eaffc..92d1b0cf95a 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -23,6 +23,21 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; +// stores copy of the memory in device buffer. used for fast state save/load +struct llama_memory_buffer { + int n_tensors = 0; + size_t total_size = 0; + + ggml_backend_buffer_ptr buf; + + ggml_context_ptr ctx; + + std::vector org; + std::vector cpy; +}; + +using llama_memory_buffers = std::map; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -128,6 +143,7 @@ struct llama_context { size_t state_set_data(const uint8_t * src, size_t size); size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags); size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags); @@ -328,6 +344,9 @@ struct llama_context { // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; + // keep copies of the per-sequence memory on the device + std::map mem_storage; + bool has_evaluated_once = false; // env: LLAMA_GRAPH_REUSE_DISABLE diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 2ff23f87cf4..fe155c92dea 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -65,8 +65,13 @@ static ggml_tensor * ggml_mul_mat_aux( ggml_tensor * res; - res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + if (!ggml_is_contiguous(cur)) { + res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n); + } else { + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + } res = ggml_mul_mat (ctx, rot, res); + ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); return res; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index ac7f9ee8650..0160a89caa2 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -166,6 +166,8 @@ struct llama_hparams { float f_attn_out_scale = 0.0f; uint32_t attn_temp_length = 0; + float f_attn_value_scale = 0.0f; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; diff --git a/examples/talk-llama/llama-io.cpp b/examples/talk-llama/llama-io.cpp index 7ad70d16334..5ec4634943f 100644 --- a/examples/talk-llama/llama-io.cpp +++ b/examples/talk-llama/llama-io.cpp @@ -1,5 +1,7 @@ #include "llama-io.h" +#include + void llama_io_write_i::write_string(const std::string & str) { uint32_t str_size = str.size(); @@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) { void llama_io_read_i::read_string(std::string & str) { uint32_t str_size; - read_to(&str_size, sizeof(str_size)); + read(&str_size, sizeof(str_size)); + + std::vector buf(str_size); + read(buf.data(), str_size); - str.assign((const char *) read(str_size), str_size); + str.assign(buf.data(), str_size); } diff --git a/examples/talk-llama/llama-io.h b/examples/talk-llama/llama-io.h index ce9216b83b1..f276af4fb96 100644 --- a/examples/talk-llama/llama-io.h +++ b/examples/talk-llama/llama-io.h @@ -12,7 +12,7 @@ class llama_io_write_i { virtual ~llama_io_write_i() = default; virtual void write(const void * src, size_t size) = 0; - virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0; + virtual void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0; // bytes written so far virtual size_t n_bytes() = 0; @@ -25,8 +25,8 @@ class llama_io_read_i { llama_io_read_i() = default; virtual ~llama_io_read_i() = default; - virtual const uint8_t * read(size_t size) = 0; - virtual void read_to(void * dst, size_t size) = 0; + virtual void read(void * dst, size_t size) = 0; + virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0; // bytes read so far virtual size_t n_bytes() = 0; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 09102f549c8..a49a055a630 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -67,6 +67,7 @@ static ggml_tensor * ggml_mul_mat_aux( res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); res = ggml_mul_mat (ctx, rot, res); + ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); return res; @@ -1900,14 +1901,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); uint32_t n_stream_cur; - io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + io.read(&n_stream_cur, sizeof(n_stream_cur)); if (n_stream_cur != n_stream) { throw std::runtime_error("n_stream mismatch"); } for (uint32_t s = 0; s < n_stream; ++s) { uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); if (cell_count == 0) { continue; @@ -2082,8 +2083,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 1) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); @@ -2092,7 +2093,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 if (hparams.n_pos_per_embd() > 1) { llama_kv_cell_ext ext; - io.read_to(&ext, sizeof(ext)); + io.read(&ext, sizeof(ext)); ubatch.pos[i + ubatch.n_tokens] = ext.y; ubatch.pos[i + ubatch.n_tokens*2] = ext.x; @@ -2101,7 +2102,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 // read the sequence id, but directly discard it - we will use dest_seq_id instead { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); } ubatch.pos[i] = pos; @@ -2143,20 +2144,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cells.pos_set(i, pos); if (hparams.n_pos_per_embd() > 1) { llama_kv_cell_ext ext; - io.read_to(&ext, sizeof(ext)); + io.read(&ext, sizeof(ext)); cells.ext_set(i, ext); } for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); @@ -2189,8 +2190,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 uint32_t v_trans; uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&v_trans, sizeof(v_trans)); + io.read(&n_layer, sizeof(n_layer)); if (n_layer != layers.size()) { LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); @@ -2217,7 +2218,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of key int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + io.read(&k_type_i_ref, sizeof(k_type_i_ref)); const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); @@ -2226,7 +2227,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of key uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + io.read(&k_size_row_ref, sizeof(k_size_row_ref)); const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); @@ -2236,13 +2237,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row); + io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * k_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * k_size_row; - ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row); + io.read_tensor(k, dst_offset, k_size_row); } } } @@ -2261,7 +2261,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -2270,7 +2270,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of value uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + io.read(&v_size_row_ref, sizeof(v_size_row_ref)); const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); @@ -2280,13 +2280,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row); + io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * v_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * v_size_row; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row); + io.read_tensor(v, dst_offset, v_size_row); } } } @@ -2305,7 +2304,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -2314,7 +2313,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read element size of value uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + io.read(&v_size_el_ref, sizeof(v_size_el_ref)); const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); @@ -2323,7 +2322,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read GQA embedding size uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); if (n_embd_v_gqa != n_embd_v_gqa_ref) { LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); return false; @@ -2335,15 +2334,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t h = sinfo.head(); for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (h + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + io.read_tensor(v, dst_offset, cell_count * v_size_el); } } else { // Slow path: scatter to non-contiguous positions for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const void * src = io.read(cell_count * v_size_el); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el); + io.read_tensor(v, dst_offset, v_size_el); } } } diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 9287fe45e96..c07f1d969cb 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -726,6 +726,10 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } + if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); + } + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count uint32_t cell_count_check = 0; for (const auto & range : cell_ranges) { @@ -743,7 +747,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i GGML_UNUSED(flags); uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); bool res = true; @@ -784,7 +788,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_layer = hparams.n_layer; io.write(&s_trans, sizeof(s_trans)); - io.write(&n_layer, sizeof(n_layer)); + io.write(&n_layer, sizeof(n_layer)); // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time @@ -879,8 +883,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 0) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); @@ -920,14 +924,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cell.pos = pos; for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) { LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max); @@ -961,8 +965,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { uint32_t s_trans; uint32_t n_layer; - io.read_to(&s_trans, sizeof(s_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&s_trans, sizeof(s_trans)); + io.read(&n_layer, sizeof(n_layer)); if (n_layer != hparams.n_layer) { LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); @@ -984,7 +988,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of key int32_t r_type_i_ref; - io.read_to(&r_type_i_ref, sizeof(r_type_i_ref)); + io.read(&r_type_i_ref, sizeof(r_type_i_ref)); const int32_t r_type_i = (int32_t) r_l[il]->type; if (r_type_i != r_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il); @@ -993,7 +997,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t r_size_row_ref; - io.read_to(&r_size_row_ref, sizeof(r_size_row_ref)); + io.read(&r_size_row_ref, sizeof(r_size_row_ref)); const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); if (r_size_row != r_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il); @@ -1002,7 +1006,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row); + io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row); } } @@ -1013,7 +1017,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { @@ -1023,7 +1027,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t s_size_row_ref; - io.read_to(&s_size_row_ref, sizeof(s_size_row_ref)); + io.read(&s_size_row_ref, sizeof(s_size_row_ref)); const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); if (s_size_row != s_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il); @@ -1032,7 +1036,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row); + io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row); } } } else { @@ -1045,7 +1049,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); @@ -1054,7 +1058,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t s_size_el_ref; - io.read_to(&s_size_el_ref, sizeof(s_size_el_ref)); + io.read(&s_size_el_ref, sizeof(s_size_el_ref)); const size_t s_size_el = ggml_type_size(s_l[il]->type); if (s_size_el != s_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il); @@ -1063,7 +1067,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read state embedding size uint32_t n_embd_s_ref; - io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref)); + io.read(&n_embd_s_ref, sizeof(n_embd_s_ref)); if (n_embd_s != n_embd_s_ref) { LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il); return false; @@ -1073,7 +1077,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_s; ++j) { const size_t dst_offset = (head + j * size) * s_size_el; - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el); + io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el); } } } diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 26864c18e97..e83056557bf 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -268,6 +268,7 @@ void llama_model_saver::add_kv_from_model() { // add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale); + add_kv(LLM_KV_ATTENTION_VALUE_SCALE, hparams.f_attn_value_scale); add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length); add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale); add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 9e2a13cbd43..ff30a2ae7a6 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -34,6 +34,285 @@ #include #include +static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params & params) { + switch (arch) { + case LLM_ARCH_LLAMA: + return new llama_model_llama(params); + case LLM_ARCH_LLAMA4: + return new llama_model_llama4(params); + case LLM_ARCH_LLAMA_EMBED: + return new llama_model_llama_embed(params); + case LLM_ARCH_MAINCODER: + return new llama_model_maincoder(params); + case LLM_ARCH_DECI: + return new llama_model_deci(params); + case LLM_ARCH_BAICHUAN: + return new llama_model_baichuan(params); + case LLM_ARCH_FALCON: + return new llama_model_falcon(params); + case LLM_ARCH_GROK: + return new llama_model_grok(params); + case LLM_ARCH_STARCODER: + return new llama_model_starcoder(params); + case LLM_ARCH_REFACT: + return new llama_model_refact(params); + case LLM_ARCH_BERT: + return new llama_model_bert(params); + case LLM_ARCH_JINA_BERT_V2: + return new llama_model_jina_bert_v2(params); + case LLM_ARCH_JINA_BERT_V3: + return new llama_model_jina_bert_v3(params); + case LLM_ARCH_NOMIC_BERT: + return new llama_model_nomic_bert(params); + case LLM_ARCH_NOMIC_BERT_MOE: + return new llama_model_nomic_bert_moe(params); + case LLM_ARCH_MODERN_BERT: + return new llama_model_modern_bert(params); + case LLM_ARCH_NEO_BERT: + return new llama_model_neo_bert(params); + case LLM_ARCH_EUROBERT: + return new llama_model_eurobert(params); + case LLM_ARCH_BLOOM: + return new llama_model_bloom(params); + case LLM_ARCH_MPT: + return new llama_model_mpt(params); + case LLM_ARCH_STABLELM: + return new llama_model_stablelm(params); + case LLM_ARCH_QWEN: + return new llama_model_qwen(params); + case LLM_ARCH_QWEN2: + return new llama_model_qwen2(params); + case LLM_ARCH_DREAM: + return new llama_model_dream(params); + case LLM_ARCH_LLADA: + return new llama_model_llada(params); + case LLM_ARCH_LLADA_MOE: + return new llama_model_llada_moe(params); + case LLM_ARCH_RND1: + return new llama_model_rnd1(params); + case LLM_ARCH_QWEN2VL: + return new llama_model_qwen2vl(params); + case LLM_ARCH_QWEN2MOE: + return new llama_model_qwen2moe(params); + case LLM_ARCH_QWEN3: + return new llama_model_qwen3(params); + case LLM_ARCH_QWEN3MOE: + return new llama_model_qwen3moe(params); + case LLM_ARCH_QWEN3VL: + return new llama_model_qwen3vl(params); + case LLM_ARCH_QWEN3VLMOE: + return new llama_model_qwen3vlmoe(params); + case LLM_ARCH_PHI2: + return new llama_model_phi2(params); + case LLM_ARCH_PHI3: + return new llama_model_phi3(params); + case LLM_ARCH_PHIMOE: + return new llama_model_phimoe(params); + case LLM_ARCH_PLAMO: + return new llama_model_plamo(params); + case LLM_ARCH_PLAMO2: + return new llama_model_plamo2(params); + case LLM_ARCH_PLAMO3: + return new llama_model_plamo3(params); + case LLM_ARCH_GPT2: + return new llama_model_gpt2(params); + case LLM_ARCH_CODESHELL: + return new llama_model_codeshell(params); + case LLM_ARCH_ORION: + return new llama_model_orion(params); + case LLM_ARCH_INTERNLM2: + return new llama_model_internlm2(params); + case LLM_ARCH_MINICPM3: + return new llama_model_minicpm3(params); + case LLM_ARCH_GEMMA: + return new llama_model_gemma(params); + case LLM_ARCH_GEMMA2: + return new llama_model_gemma2(params); + case LLM_ARCH_GEMMA3: + return new llama_model_gemma3(params); + case LLM_ARCH_GEMMA3N: + return new llama_model_gemma3n(params); + case LLM_ARCH_GEMMA4: + return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA_EMBEDDING: + return new llama_model_gemma_embedding(params); + case LLM_ARCH_STARCODER2: + return new llama_model_starcoder2(params); + case LLM_ARCH_MAMBA: + return new llama_model_mamba(params); + case LLM_ARCH_MAMBA2: + return new llama_model_mamba2(params); + case LLM_ARCH_JAMBA: + return new llama_model_jamba(params); + case LLM_ARCH_XVERSE: + return new llama_model_xverse(params); + case LLM_ARCH_COMMAND_R: + return new llama_model_command_r(params); + case LLM_ARCH_COHERE2: + return new llama_model_cohere2(params); + case LLM_ARCH_DBRX: + return new llama_model_dbrx(params); + case LLM_ARCH_OLMO: + return new llama_model_olmo(params); + case LLM_ARCH_OLMO2: + return new llama_model_olmo2(params); + case LLM_ARCH_OLMOE: + return new llama_model_olmoe(params); + case LLM_ARCH_OPENELM: + return new llama_model_openelm(params); + case LLM_ARCH_GPTNEOX: + return new llama_model_gptneox(params); + case LLM_ARCH_ARCTIC: + return new llama_model_arctic(params); + case LLM_ARCH_DEEPSEEK: + return new llama_model_deepseek(params); + case LLM_ARCH_DEEPSEEK2: + return new llama_model_deepseek2(params); + case LLM_ARCH_DEEPSEEK2OCR: + return new llama_model_deepseek2ocr(params); + case LLM_ARCH_GLM_DSA: + return new llama_model_glm_dsa(params); + case LLM_ARCH_MISTRAL4: + return new llama_model_mistral4(params); + case LLM_ARCH_CHATGLM: + return new llama_model_chatglm(params); + case LLM_ARCH_GLM4: + return new llama_model_glm4(params); + case LLM_ARCH_GLM4_MOE: + return new llama_model_glm4_moe(params); + case LLM_ARCH_BITNET: + return new llama_model_bitnet(params); + case LLM_ARCH_T5: + return new llama_model_t5(params); + case LLM_ARCH_T5ENCODER: + return new llama_model_t5encoder(params); + case LLM_ARCH_JAIS: + return new llama_model_jais(params); + case LLM_ARCH_JAIS2: + return new llama_model_jais2(params); + case LLM_ARCH_NEMOTRON: + return new llama_model_nemotron(params); + case LLM_ARCH_NEMOTRON_H: + return new llama_model_nemotron_h(params); + case LLM_ARCH_NEMOTRON_H_MOE: + return new llama_model_nemotron_h_moe(params); + case LLM_ARCH_EXAONE: + return new llama_model_exaone(params); + case LLM_ARCH_EXAONE4: + return new llama_model_exaone4(params); + case LLM_ARCH_EXAONE_MOE: + return new llama_model_exaone_moe(params); + case LLM_ARCH_RWKV6: + return new llama_model_rwkv6(params); + case LLM_ARCH_RWKV6QWEN2: + return new llama_model_rwkv6qwen2(params); + case LLM_ARCH_RWKV7: + return new llama_model_rwkv7(params); + case LLM_ARCH_ARWKV7: + return new llama_model_arwkv7(params); + case LLM_ARCH_GRANITE: + return new llama_model_granite(params); + case LLM_ARCH_GRANITE_MOE: + return new llama_model_granite_moe(params); + case LLM_ARCH_MINICPM: + return new llama_model_minicpm(params); + case LLM_ARCH_GRANITE_HYBRID: + return new llama_model_granite_hybrid(params); + case LLM_ARCH_CHAMELEON: + return new llama_model_chameleon(params); + case LLM_ARCH_WAVTOKENIZER_DEC: + return new llama_model_wavtokenizer_dec(params); + case LLM_ARCH_PLM: + return new llama_model_plm(params); + case LLM_ARCH_BAILINGMOE: + return new llama_model_bailingmoe(params); + case LLM_ARCH_BAILINGMOE2: + return new llama_model_bailingmoe2(params); + case LLM_ARCH_SEED_OSS: + return new llama_model_seed_oss(params); + case LLM_ARCH_DOTS1: + return new llama_model_dots1(params); + case LLM_ARCH_ARCEE: + return new llama_model_arcee(params); + case LLM_ARCH_AFMOE: + return new llama_model_afmoe(params); + case LLM_ARCH_ERNIE4_5: + return new llama_model_ernie4_5(params); + case LLM_ARCH_ERNIE4_5_MOE: + return new llama_model_ernie4_5_moe(params); + case LLM_ARCH_PADDLEOCR: + return new llama_model_paddleocr(params); + case LLM_ARCH_HUNYUAN_MOE: + return new llama_model_hunyuan_moe(params); + case LLM_ARCH_HUNYUAN_VL: + return new llama_model_hunyuan_vl(params); + case LLM_ARCH_HUNYUAN_DENSE: + return new llama_model_hunyuan_dense(params); + case LLM_ARCH_SMOLLM3: + return new llama_model_smollm3(params); + case LLM_ARCH_OPENAI_MOE: + return new llama_model_openai_moe(params); + case LLM_ARCH_FALCON_H1: + return new llama_model_falcon_h1(params); + case LLM_ARCH_LFM2: + return new llama_model_lfm2(params); + case LLM_ARCH_LFM2MOE: + return new llama_model_lfm2moe(params); + case LLM_ARCH_SMALLTHINKER: + return new llama_model_smallthinker(params); + case LLM_ARCH_GROVEMOE: + return new llama_model_grovemoe(params); + case LLM_ARCH_APERTUS: + return new llama_model_apertus(params); + case LLM_ARCH_MINIMAX_M2: + return new llama_model_minimax_m2(params); + case LLM_ARCH_COGVLM: + return new llama_model_cogvlm(params); + case LLM_ARCH_PANGU_EMBED: + return new llama_model_pangu_embed(params); + case LLM_ARCH_QWEN3NEXT: + return new llama_model_qwen3next(params); + case LLM_ARCH_QWEN35: + return new llama_model_qwen35(params); + case LLM_ARCH_QWEN35MOE: + return new llama_model_qwen35moe(params); + case LLM_ARCH_MISTRAL3: + return new llama_model_mistral3(params); + case LLM_ARCH_MIMO2: + return new llama_model_mimo2(params); + case LLM_ARCH_KIMI_LINEAR: + return new llama_model_kimi_linear(params); + case LLM_ARCH_STEP35: + return new llama_model_step35(params); + default: + throw std::runtime_error(std::string("unsupported model architecture: '") + llm_arch_name(arch) + "'"); + } + +} + +llama_model * llama_model_create(llm_arch arch, const llama_model_params & params) { + llama_model * model = llama_model_mapping(arch, params); + + if (model != nullptr) { + model->arch = arch; + auto & devices = model->devices; + if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { + throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); + } + } + + return model; +} + +llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params) { + llm_arch arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } + + return llama_model_create(arch, params); +} + struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; const llama_hparams & hparams = ud->model->hparams; @@ -688,22 +967,12 @@ llama_model::~llama_model() { } } -void llama_model::load_stats(llama_model_loader & ml) { +void llama_model_base::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; pimpl->n_bytes = ml.n_bytes; } -void llama_model::load_arch(llama_model_loader & ml) { - arch = ml.get_arch(); - if (arch == LLM_ARCH_UNKNOWN) { - throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); - } - if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { - throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); - } -} - -void llama_model::load_hparams(llama_model_loader & ml) { +void llama_model_base::load_hparams(llama_model_loader & ml) { const gguf_context * ctx = ml.metadata; // get metadata as string @@ -862,8215 +1131,931 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false); } - // for differentiating model types - uint32_t n_vocab = 0; - ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); - // for classifier models ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false); if (!classifier_labels.empty()) { hparams.n_cls_out = classifier_labels.size(); } - // arch-specific KVs - switch (arch) { - case LLM_ARCH_LLAMA: - case LLM_ARCH_LLAMA_EMBED: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // per-arch hparams + load_arch_hparams(ml); - if (hparams.n_expert == 8) { - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8x7B; break; - case 56: type = LLM_TYPE_8x22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } else { - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B - case 22: type = LLM_TYPE_1B; break; - case 26: type = LLM_TYPE_3B; break; - case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B - case 30: type = LLM_TYPE_256M; break; // smoldocling 256M - // granite uses a vocab with len 49152 - case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; - case 36: type = LLM_TYPE_8B; break; // granite - case 40: type = LLM_TYPE_13B; break; - case 48: type = LLM_TYPE_34B; break; - case 60: type = LLM_TYPE_30B; break; - case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } - } break; - case LLM_ARCH_LLAMA4: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa == 0) { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope - } else { - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; - hparams.n_attn_temp_floor_scale = 8192; - hparams.f_attn_temp_scale = 0.1f; - hparams.f_attn_temp_offset = 1.0f; - uint32_t swa_period = 4; // pattern: 3 chunked - 1 full - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } + pimpl->n_bytes = ml.n_bytes; - switch (hparams.n_expert) { - case 0: { - // MobileLLM (no MoE) - switch (hparams.n_embd) { - case 2048: type = LLM_TYPE_140M; break; - case 4096: type = LLM_TYPE_360M; break; - case 6144: type = LLM_TYPE_950M; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case 16: type = LLM_TYPE_17B_16E; break; - case 128: type = LLM_TYPE_17B_128E; break; - default: type = LLM_TYPE_UNKNOWN; - } + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); - hparams.use_kq_norm = type != LLM_TYPE_17B_128E; - } break; - case LLM_ARCH_ARCEE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (hparams.f_max_alibi_bias > 0.0f) { + hparams.use_alibi = true; + } - // Arcee uses the same structure as Llama - switch (hparams.n_layer) { - case 36: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_AFMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - - // Set up interleaved sliding window attention (ISWA) - // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) - if (hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + hparams.rope_type = llama_model_rope_type(this); +} - // Default to sigmoid if not set - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } +void llama_model_base::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); - switch (hparams.n_layer) { - case 56: type = LLM_TYPE_6B; break; - case 32: type = LLM_TYPE_26B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DECI: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 80: type = LLM_TYPE_70B; break; - case 162: type = LLM_TYPE_405B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINICPM: - { - // Backward-compatible defaults for older MiniCPM GGUFs - hparams.f_embedding_scale = 12.0f; - hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); - hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; + vocab.load(ml, kv); +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +bool llama_model_base::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; - // Optional KV reads, override defaults if present in newer GGUF exports - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false); + const int n_layer = hparams.n_layer; + const int n_gpu_layers = this->n_gpu_layers(); - // MiniCPM uses rope by default, unlike Granite which uses it as a switch - hparams.rope_finetuned = true; + const bool use_mmap_buffer = true; - switch (hparams.n_layer) { - case 52: type = LLM_TYPE_1B; break; - case 40: type = LLM_TYPE_2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINICPM3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + this->ml = &ml; // to be used by create_tensor() and load_arch_tensors() - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GROK: - { - // defaults for old GGUFs - hparams.yarn_beta_fast = 8.0f; - hparams.f_logit_scale = 0.5773502691896257f; - hparams.f_embedding_scale = 78.38367176906169f; - hparams.f_attn_out_scale = 0.08838834764831845f; - hparams.f_attn_logit_softcapping = 30.0f; - hparams.f_router_logit_softcapping = 30.0f; - // no final_logit_softcapping in grok-1 - hparams.f_final_logit_softcapping = 0.0f; - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); - ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - - switch (hparams.n_layer) { - case 64: type = LLM_TYPE_314B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_FALCON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", + __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 60: type = LLM_TYPE_40B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BAICHUAN: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); + for (const auto & dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); + } - if (type == LLM_TYPE_13B) { - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } - } break; - case LLM_ARCH_STARCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - case 42: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_15B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_REFACT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } break; - case LLM_ARCH_BERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 3: - type = LLM_TYPE_17M; break; // bge-micro - case 6: - type = LLM_TYPE_22M; break; // MiniLM-L6 - case 12: - switch (hparams.n_embd) { - case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small - case 768: type = LLM_TYPE_109M; break; // bge-base - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - type = LLM_TYPE_335M; break; // bge-large - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MODERN_BERT: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - uint32_t swa_period = 3; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period, true); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i].dev; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 12: - type = LLM_TYPE_47M; break; // granite-embedding-small - case 22: - type = LLM_TYPE_149M; break; // modern-bert-base - case 28: - type = LLM_TYPE_395M; break; // modern-bert-large - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JINA_BERT_V2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - hparams.f_max_alibi_bias = 8.0f; - - switch (hparams.n_layer) { - case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small - case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JINA_BERT_V3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: - type = LLM_TYPE_558M; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - - if (hparams.n_layer == 12 && hparams.n_embd == 768) { - if (arch == LLM_ARCH_NOMIC_BERT) { - type = LLM_TYPE_137M; - } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { - type = LLM_TYPE_475M; - } - } - } break; - case LLM_ARCH_NEO_BERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - if (hparams.n_layer == 28) { - type = LLM_TYPE_250M; - } - } break; - case LLM_ARCH_EUROBERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - if (hparams.n_layer == 12) { - type = LLM_TYPE_SMALL; // 0.2B - } - } break; - case LLM_ARCH_BLOOM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 30: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } break; - case LLM_ARCH_MPT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_30B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_STABLELM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_12B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN2VL: - { - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + // devices can return 0 bytes for free and total memory if they do not + // have any to report. in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + ggml_backend_dev_memory(cpu_dev, &free, &total); } - // fall through - case LLM_ARCH_QWEN2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; - case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; - case 32: type = LLM_TYPE_7B; break; - case 36: type = LLM_TYPE_3B; break; - case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; - case 48: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - case 80: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DREAM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // Dream models are primarily 7B with 28 layers - switch (hparams.n_layer) { - case 28: - type = LLM_TYPE_7B; - break; - default: - type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } break; - case LLM_ARCH_LLADA: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion - switch (hparams.n_layer) { - case 32: - type = LLM_TYPE_8B; - break; - default: - type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } break; - case LLM_ARCH_LLADA_MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // diffusion language model uses non-causal attention - hparams.causal_attn = false; - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_A1_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RND1: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } break; - case LLM_ARCH_QWEN2MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_A2_7B; break; - case 28: type = LLM_TYPE_57B_A14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; - case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; - case 40: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAINCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3VL: - { - ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_1_7B; break; - case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - case 94: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3VLMOE: - { - ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - case 94: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PHI2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + splits[i] = free; + } + } else { + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); + } - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PHI3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; + } + for (size_t i = 0; i < n_devices(); ++i) { + splits[i] /= split_sum; + } - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } + const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu).dev; + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; - if (found_swa && hparams.n_swa > 0) { - LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer); + for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } - // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` - hparams.swa_type = LLAMA_SWA_TYPE_NONE; + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer); - hparams.n_swa = 0; - hparams.set_swa_pattern(1); - } - } break; - case LLM_ARCH_PHIMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_16x3_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // create tensors for the weights + { + // TODO: move to a separate function + const auto tn = LLM_TN(arch); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; - // Load Mamba SSM parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + if (n_expert > 0 && n_expert_used == 0) { + throw std::runtime_error("model has expert layers but no expert layers are used"); + } - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } + layers.resize(n_layer); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; - case 32: - if (hparams.n_embd == 2048) { - type = LLM_TYPE_2B; - } else if (hparams.n_embd == 4096) { - type = LLM_TYPE_8B; - } - break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - uint32_t swa_period = 8; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // call the per-model loading function + load_arch_tensors(ml); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GPT2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 12: type = LLM_TYPE_SMALL; break; - case 24: type = LLM_TYPE_MEDIUM; break; - case 36: type = LLM_TYPE_LARGE; break; - case 48: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_CODESHELL: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 42: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_ORION: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) + // this avoids having to add scale loading to every architecture + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_INTERNLM2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // attention weight scales (per-tensor, shape {1}) + if (!layer.wq_s && layer.wq) { + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_s && layer.wk) { + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_s && layer.wv) { + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_s && layer.wo) { + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_s && layer.wqkv) { + layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_s && layer.wqkv_gate) { + layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_2B; break; - case 28: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA2: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; // default value of gemma 2 - uint32_t swa_period = 2; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - hparams.attn_soft_cap = true; - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_2B; break; - case 42: type = LLM_TYPE_9B; break; - case 46: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } - - // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 - hparams.f_attention_scale = type == LLM_TYPE_27B - ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); - } break; - case LLM_ARCH_GEMMA3: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 6; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // dense FFN weight scales (per-tensor, shape {1}) + if (!layer.ffn_gate_s && layer.ffn_gate) { + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_s && layer.ffn_down) { + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_s && layer.ffn_up) { + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - hparams.f_final_logit_softcapping = 0.0f; - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_270M; break; - case 26: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_8B; break; // Rnj-1 - case 34: type = LLM_TYPE_4B; break; - case 48: type = LLM_TYPE_12B; break; - case 62: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } + // MoE expert weight scales (per-expert, shape {n_expert}) + if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_s && layer.ffn_down_exps) { + layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { + layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } - // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 - hparams.f_attention_scale = type == LLM_TYPE_27B - ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); - } break; - case LLM_ARCH_GEMMA3N: - { - uint32_t swa_period = 5; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(swa_period); - - hparams.n_layer_kv_from_start = 20; - hparams.f_attention_scale = 1.0f; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_E2B; break; - case 35: type = LLM_TYPE_E4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA4: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); - - uint32_t n_kv_shared_layers = 0; - ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); - - hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; - hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_26B_A4B; break; - case 35: type = LLM_TYPE_E2B; break; - case 42: type = LLM_TYPE_E4B; break; - case 60: type = LLM_TYPE_31B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA_EMBEDDING: - { - hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - uint32_t swa_period = 6; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); + // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_in_s && layer.ssm_in) { + layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_s && layer.ssm_out) { + layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_s && layer.ssm_alpha) { + layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_s && layer.ssm_beta) { + layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - hparams.causal_attn = false; // embeddings do not use causal attention + // input scales + if (!layer.wq_in_s && layer.wq) { + layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_in_s && layer.wk) { + layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_in_s && layer.wv) { + layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_in_s && layer.wo) { + layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_in_s && layer.wqkv) { + layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { + layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_in_s && layer.ffn_gate) { + layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_in_s && layer.ffn_down) { + layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_in_s && layer.ffn_up) { + layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { + layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { + layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_in_in_s && layer.ssm_in) { + layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_in_s && layer.ssm_out) { + layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { + layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_in_s && layer.ssm_beta) { + layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + } - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.done_getting_tensors(); - //applied only if model converted with --sentence-transformers-dense-modules - ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); - ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); - ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); - ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + // populate tensors_by_name + for (auto & [_, ctx_ptr] : ml.ctx_map) { + for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } - GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); - GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); + pimpl->mappings.reserve(ml.mappings.size()); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_0_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k())); + // create the backend buffers + std::vector> ctx_buf_maps; + ctx_buf_maps.reserve(ml.ctx_map.size()); - } break; - case LLM_ARCH_STARCODER2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_3B; break; - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_15B; break; - case 52: type = LLM_TYPE_20B; break; // granite - case 88: type = LLM_TYPE_34B; break; // granite - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAMBA: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 24: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_SMALL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 48: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_MEDIUM; break; - case 1536: type = LLM_TYPE_LARGE; break; - case 2048: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 64: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAMBA2: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 24: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_SMALL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 48: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_MEDIUM; break; - case 1536: type = LLM_TYPE_LARGE; break; - case 2048: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 64: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JAMBA: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + // Ensure we have enough capacity for the maximum backend buffer we will potentially create + const size_t n_max_backend_buffer = ml.ctx_map.size() * ml.files.size(); + pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + for (auto & [buft, ctx_ptr] : ml.ctx_map) { + ggml_context * ctx = ctx_ptr.get(); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } + // skip contexts without tensors + if (ggml_get_first_tensor(ctx) == nullptr) { + continue; + } - switch (hparams.n_layer) { - // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. - case 12: // 900M 8x???M - case 32: // 51B 16x?B - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_XVERSE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - case 80: type = LLM_TYPE_65B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COMMAND_R: - { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_35B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COHERE2: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DBRX: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + llama_buf_map buf_map; + buf_map.reserve(n_max_backend_buffer); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_16x12B; break; - default: type = LLM_TYPE_UNKNOWN; + // check if it is possible to use buffer_from_host_ptr with this buffer type + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (!dev) { + // FIXME: workaround for CPU backend buft having a NULL device + dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); } - } break; - case LLM_ARCH_OLMO: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - - switch (hparams.n_layer) { - case 22: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_7B; break; - case 80: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OLMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + } + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; + bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SEED_OSS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 64: type = LLM_TYPE_36B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OLMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_A1_7B; break; - default: type = LLM_TYPE_UNKNOWN; + std::vector bufs; + if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { + GGML_ASSERT(!ml.no_alloc); + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + // only the mmap region containing the tensors in the model is mapped to the backend buffer + // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, + // then we could just use metal for all layers + // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size + void * addr = nullptr; + size_t first, last; // NOLINT + ml.get_mapping_range(&first, &last, &addr, idx, ctx); + if (first >= last) { + continue; } - } break; - case LLM_ARCH_OPENELM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_270M; break; - case 20: type = LLM_TYPE_450M; break; - case 28: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; + const size_t max_size = ggml_get_max_tensor_size(ctx); + ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - } break; - case LLM_ARCH_GPTNEOX: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); - switch (hparams.n_layer) { - case 6: - switch (hparams.n_ff()) { - case 512: type = LLM_TYPE_14M; break; - case 2048: type = LLM_TYPE_70M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 12: - switch (hparams.n_ff()) { - case 3072: type = LLM_TYPE_160M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 16: - switch (hparams.n_ff()) { - case 8192: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_ff()) { - case 4096: type = LLM_TYPE_410M; break; - case 8192: type = LLM_TYPE_1_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 32: - switch (hparams.n_ff()) { - case 10240: type = LLM_TYPE_2_8B; break; - case 16384: type = LLM_TYPE_6_9B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 36: - switch (hparams.n_ff()) { - case 20480: type = LLM_TYPE_12B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 44: - switch (hparams.n_ff()) { - case 24576: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; + bufs.emplace_back(buf); + buf_map.emplace(idx, buf); + } + } else { + ggml_backend_buffer_t buf; + if (ml.no_alloc) { + buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them } - } break; - case LLM_ARCH_ARCTIC: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + } else { + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer + } + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + if (use_mlock && ggml_backend_buffer_is_host(buf)) { + pimpl->mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = pimpl->mlock_bufs.back(); + mlock_buf->init (ggml_backend_buffer_get_base(buf)); + mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); + } + bufs.emplace_back(buf); + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + buf_map.emplace(idx, buf); + } + } - if (hparams.n_expert == 128) { - switch (hparams.n_layer) { - case 35: type = LLM_TYPE_10B_128x3_66B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } else { - type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - - switch (hparams.n_ff_exp) { - case 1408: type = LLM_TYPE_16B; break; - case 1792: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK2: - case LLM_ARCH_MISTRAL4: - { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + for (auto & buf : bufs) { + // indicate that this buffer contains weights + // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight + ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - if (!is_lite) { - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - // for compatibility with existing DeepSeek V2 and V2.5 GGUFs - // that have no expert_gating_func model parameter set - if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { - // GLM 4.7 Lite - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } else { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; - } - } + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { - // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] - // cancel the factor from the convert script - hparams.rope_yarn_log_mul /= 0.1f; - } + ctx_buf_maps.emplace_back(ctx, buf_map); + } - // (optional) temperature tuning - used by mistral-large - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? + if (llama_supports_gpu_offload()) { + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - hparams.f_attn_temp_offset = 0.0f; + int n_repeating = n_gpu; + if (n_repeating > 0) { + LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + n_repeating--; + } + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - switch (hparams.n_layer) { - case 27: type = LLM_TYPE_16B; break; - case 47: type = LLM_TYPE_30B_A3B; break; - case 60: type = LLM_TYPE_236B; break; - case 61: type = LLM_TYPE_671B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK2OCR: - { - // similar to deepseek2, but without MLA - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; - } + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; - switch (hparams.n_layer) { - case 12: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_CHATGLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: { - if (hparams.n_head(0) == 16) { - type = LLM_TYPE_1_5B; - } else { - type = LLM_TYPE_6B; - } - } break; - case 40: { - if (hparams.n_head(0) == 24) { - type = LLM_TYPE_4B; - } else { - type = LLM_TYPE_9B; - } - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM4: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // NextN/MTP parameters (GLM-OCR) - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 17: type = LLM_TYPE_1B; break; // GLM-OCR - case 40: type = LLM_TYPE_9B; break; - case 61: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM4_MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // MoE parameters - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - // Expert gating function (GLM-4.5 uses sigmoid) - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } - - // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + } - switch (hparams.n_layer) { - case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) - case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open - case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM_DSA: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // MoE parameters - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - // deepseek MLA parameters - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - - // DSA parameters - ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); - ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); - ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); - - // Expert gating function (GLM-4.5 uses sigmoid) - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } + // print memory requirements per buffer type + for (auto & [_, bufs] : pimpl->ctxs_bufs) { + for (auto & buf: bufs) { + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", + __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + } + } - // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + if (ml.no_alloc) { + return true; + } - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + // load tensor data + for (auto & [ctx, buf_map] : ctx_buf_maps) { + if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + return false; + } + } - switch (hparams.n_layer) { - case 79: type = LLM_TYPE_744B_A40B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BITNET: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (use_mmap_buffer) { + for (auto & mapping : ml.mappings) { + pimpl->mappings.emplace_back(std::move(mapping)); + } + } - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_T5: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + return true; +} - uint32_t dec_start_token_id; - if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { - hparams.dec_start_token_id = dec_start_token_id; - } +ggml_tensor * llama_model_base::create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) { + const buft_list_t * buft_list_layer = tn.bid == -1 ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; + return ml.create_tensor( + hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, + tn, ne, flags); +} - hparams.dec_n_layer = hparams.n_layer; - ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); - - switch (hparams.n_layer) { - case 6: type = LLM_TYPE_60M; break; // t5-small - case 8: type = LLM_TYPE_80M; break; // flan-t5-small - case 12: - switch (hparams.n_ff()) { - case 3072: type = LLM_TYPE_220M; break; // t5-base - case 2048: type = LLM_TYPE_250M; break; // flan-t5-base - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_ff()) { - case 4096: type = LLM_TYPE_770M; break; // t5-large - case 2816: type = LLM_TYPE_780M; break; // flan-t5-large - case 16384: type = LLM_TYPE_3B; break; // t5-3b - case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl - case 65536: type = LLM_TYPE_11B; break; // t5-11b - case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_T5ENCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); - type = LLM_TYPE_UNKNOWN; - } break; - case LLM_ARCH_JAIS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1_3B; break; - case 40: type = LLM_TYPE_13B; break; - /* TODO: add variants */ - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JAIS2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); +std::string llama_model::arch_name() const { + return llm_arch_name(arch); +} - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - case 68: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NEMOTRON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // A layer is recurrent IFF the n_head_kv value is set to 0 and - // the n_ff value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); - } +std::string llama_model::type_name() const { + return llm_type_name(type); +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +std::string llama_model::desc() const { + return pimpl->desc_str; +} - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); +size_t llama_model::size() const { + return pimpl->n_bytes; +} - switch (hparams.n_layer) { - case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B - case 56: type = LLM_TYPE_9B; break; - case 88: type = LLM_TYPE_120B_A12B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +size_t llama_model::n_tensors() const { + return tensors_by_name.size(); +} - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE4: - { - if (hparams.n_layer == 64) { // 32B - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } +size_t llama_model::n_devices() const { + return devices.size(); +} - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +const float * llama_model::tensor_split() const { + return params.tensor_split; +} - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_1_2B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE_MOE: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 128; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_30B_A3B; break; - case 48: - case 49: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); - ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); - ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); - ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1_6B; break; - case 32: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 61: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); - ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); - ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); - ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); - ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - - switch (hparams.n_layer) { - case 12: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_190M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_450M; break; - case 2048: type = LLM_TYPE_1_5B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 28: - switch (hparams.n_embd) { - case 1536: type = LLM_TYPE_1_5B; break; - case 3584: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 32: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_2_9B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 61: - switch (hparams.n_embd) { - case 4096: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); - - // Granite uses rope_finetuned as a switch for rope, so default to true - bool rope_finetuned = true; - ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); - hparams.rope_finetuned = rope_finetuned; - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_3B; break; - // Add additional layer/vocab/etc checks here for other model sizes - default: type = LLM_TYPE_UNKNOWN; - } +uint32_t llama_model::n_gpu_layers() const { + return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; +} - // For Granite MoE Shared - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false); - - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Granite uses rope_finetuned as a switch for rope, so default to true - bool rope_finetuned = true; - ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); - hparams.rope_finetuned = rope_finetuned; - - // A layer is recurrent IFF the n_head_kv value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } +llama_split_mode llama_model::split_mode() const { + return params.split_mode; +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +std::map llama_model::memory_breakdown() const { + std::map ret; + for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { + if (hparams.no_alloc) { + GGML_ASSERT(bufs.size() == 1); + ggml_backend_buffer_t buf = bufs[0].get(); + GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr); + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); + ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); + } else { + for (const auto & buf : bufs) { + // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + } + } + } + return ret; +} - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; - case 2048: case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } +uint64_t llama_model::n_elements() const { + return pimpl->n_elements; +} - // For Granite MoE Shared - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); - } break; - case LLM_ARCH_CHAMELEON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default - ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_34B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_WAVTOKENIZER_DEC: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); - ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - } break; - case LLM_ARCH_BAILINGMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_16B; break; - case 88: type = LLM_TYPE_290B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BAILINGMOE2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 20: type = LLM_TYPE_16B_A1B; break; - case 21: type = LLM_TYPE_16B_A1B; break; - case 32: type = LLM_TYPE_100B_A6B; break; - case 33: type = LLM_TYPE_100B_A6B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DOTS1: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_142B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_ERNIE4_5_MOE: - case LLM_ARCH_PADDLEOCR: - { - // paddleocr need mrope_section - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (arch == LLM_ARCH_ERNIE4_5_MOE) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - } +void llama_model::print_info() const { + const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_0_3B; break; - case 28: type = LLM_TYPE_21B_A3B; break; - case 54: type = LLM_TYPE_300B_A47B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_FALCON_H1: - { - // Common parameters - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - // SSM parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); - - switch (hparams.n_layer) { - case 36: - type = LLM_TYPE_0_5B; break; - case 24: - type = LLM_TYPE_1_5B; break; - case 66: - type = LLM_TYPE_1B; break; - case 32: - type = LLM_TYPE_3B; break; - case 44: - type = LLM_TYPE_7B; break; - case 72: - type = LLM_TYPE_34B; break; - default: - type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + auto print_f = [](const std::function & f, uint32_t n) { + bool is_var = false; - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_A13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_HUNYUAN_VL: - case LLM_ARCH_HUNYUAN_DENSE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) - if (hparams.rope_scaling_alpha > 0.0f) { - const int dim = hparams.n_embd_head_k(); - hparams.rope_freq_base_train = hparams.rope_freq_base_train - * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); - } + std::vector v; + for (uint32_t i = 0; i < n; ++i) { + v.push_back(f(i)); + if (v[i] != v[0]) { + is_var = true; + } + } - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_0_5B; break; - case 2048: type = LLM_TYPE_1_8B; break; - case 3072: type = LLM_TYPE_4B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SMOLLM3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.n_no_rope_layer_step = 4; + std::stringstream ss; - switch (hparams.n_layer) { - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OPENAI_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 2; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_20B; break; - case 36: type = LLM_TYPE_120B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_LFM2: - { - ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; - } - hparams.n_layer_dense_lead = hparams.n_layer; - switch (hparams.n_ff()) { - case 4608: type = LLM_TYPE_350M; break; - case 6912: type = LLM_TYPE_700M; break; - case 8192: type = LLM_TYPE_1_2B; break; - case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; - } - if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; - } - } - } break; - case LLM_ARCH_LFM2MOE: - { - ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + if (is_var) { + ss << "["; + for (uint32_t i = 0; i < n; ++i) { + ss << v[i]; + if (i < n - 1) { + ss << ", "; } + } + ss << "]"; + } else { + ss << v[0]; + } - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_8B_A1B; break; - case 40: type = LLM_TYPE_24B_A2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SMALLTHINKER: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period, true); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; - } + return ss.str(); + }; - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + // hparams + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_4B; break; - case 52: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GROVEMOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp, false); - ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); - ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_APERTUS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINIMAX_M2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (!hparams.vocab_only) { + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa); + LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa); + LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa); + } + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + // MRoPE (Multi-axis Rotary Position Embedding) sections + if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + } + if (!classifier_labels.empty()) { + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_230B_A10B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COGVLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PANGU_EMBED: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 - case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3NEXT: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - // Load linear attention (gated delta net) parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Mark recurrent layers (linear attention layers) - { - uint32_t full_attn_interval = 4; - ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); - } - } + size_t i = 0; + for (const auto & label : classifier_labels) { + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + } + } - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_80B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN35: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - // Load linear attention (gated delta net) parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Mark recurrent layers (linear attention layers) - { - uint32_t full_attn_interval = 4; - ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); - } - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - switch (hparams.n_layer) { - case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; - case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; - case 64: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN35MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - // Load linear attention (gated delta net) parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Mark recurrent layers (linear attention layers) - { - uint32_t full_attn_interval = 4; - ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); - } - } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_35B_A3B; break; - case 48: type = LLM_TYPE_122B_A10B; break; - case 60: type = LLM_TYPE_397B_A17B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MISTRAL3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f); + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - hparams.f_attn_temp_offset = 0.0f; + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - // TODO: maybe add n_attn_temp_floor_scale as a separate KV? - if (hparams.f_attn_temp_scale != 0.0f) { - hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; - if (hparams.n_attn_temp_floor_scale == 0) { - throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); - } - } + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_3B; break; - case 34: type = LLM_TYPE_8B; break; - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MIMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_310B_A15B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_KIMI_LINEAR: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); - - // MLA qk_rope_head_dim (for reference) - // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 - - // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) - // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent - } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } - // MoE parameters - Kimi uses moe_intermediate_size = 1024 - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - switch (hparams.n_layer) { - case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_STEP35: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } + } - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + vocab.print_info(); +} - // full_attention layer only use half of the RoPE dimensions - hparams.n_rot_full = hparams.n_rot_full / 2; +ggml_backend_dev_t llama_model::dev_layer(int il) const { + return pimpl->dev_layer.at(il).dev; +} - // MoE + SWA parameters - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); +ggml_backend_dev_t llama_model::dev_output() const { + return pimpl->dev_output.dev; +} - // Step35 uses sigmoid gating by default (if not set in GGUF) - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } +template +static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } - switch (hparams.n_layer) { - case 45: type = LLM_TYPE_196B_A11B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - default: throw std::runtime_error("unsupported model architecture: " + arch_name()); + ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; + ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + assert(op_tensor->src[i]->buffer == nullptr); + op_tensor->src[i]->buffer = buf.get(); + } } - pimpl->n_bytes = ml.n_bytes; + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); + return op_supported; +} - if (hparams.f_max_alibi_bias > 0.0f) { - hparams.use_alibi = true; +template +static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } } - hparams.rope_type = llama_model_rope_type(this); + throw std::runtime_error(format("no suitable buffer type found")); } -void llama_model::load_vocab(llama_model_loader & ml) { - const auto kv = LLM_KV(arch); +ggml_backend_buffer_type_t llama_model::select_buft(int il) const { + return ::select_buft( + *pimpl->dev_layer.at(il).buft_list, + [&](ggml_context * ctx) { + ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + return ggml_add(ctx, cur, layer_dir); + }); +} - vocab.load(ml, kv); +bool llama_model::has_tensor_overrides() const { + return pimpl->has_tensor_overrides; } -bool llama_model::load_tensors(llama_model_loader & ml) { - const auto & split_mode = params.split_mode; - const auto & use_mlock = params.use_mlock; - const auto & tensor_split = params.tensor_split; - - const int n_layer = hparams.n_layer; - const int n_gpu_layers = this->n_gpu_layers(); - - const bool use_mmap_buffer = true; - - LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", - __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); - - // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); - for (const auto & dev : devices) { - buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); - // add CPU buffer types as a fallback - buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); - pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); - } - - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - - // calculate the split points - bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); - std::vector splits(n_devices()); - if (all_zero) { - // default split, by free memory - for (size_t i = 0; i < n_devices(); ++i) { - ggml_backend_dev_t dev = devices[i].dev; - size_t total; - size_t free; - ggml_backend_dev_memory(dev, &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - splits[i] = free; - } - } else { - std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); - } - - // sum and normalize the splits to get the split points - float split_sum = 0.0f; - for (size_t i = 0; i < n_devices(); ++i) { - split_sum += splits[i]; - splits[i] = split_sum; - } - for (size_t i = 0; i < n_devices(); ++i) { - splits[i] /= split_sum; - } - - const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); - auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); - if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); - return {cpu_dev, &pimpl->cpu_buft_list}; - } - const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); - auto * dev = devices.at(layer_gpu).dev; - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); - return {dev, &pimpl->gpu_buft_list.at(dev)}; - }; - - // assign the input layer - // there is very little benefit to offloading the input layer, so always keep it on the CPU - pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; - - // assign the repeating layers to the devices according to the splits - pimpl->dev_layer.resize(n_layer); - for (int il = 0; il < n_layer; ++il) { - pimpl->dev_layer[il] = get_layer_buft_list(il); +const ggml_tensor * llama_model::get_tensor(const char * name) const { + auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), + [name](const std::pair & it) { + return it.first == name; + }); + if (it == tensors_by_name.end()) { + return nullptr; } - // assign the output layer - pimpl->dev_output = get_layer_buft_list(n_layer); - - const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; - const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; - const auto TENSOR_SKIP_IF_VIRTUAL = llama_model_loader::TENSOR_SKIP_IF_VIRTUAL; - - // create tensors for the weights - { - // note: cast to int64_t since we will use these for the tensor dimensions - const int64_t n_head = hparams.n_head(); - const int64_t n_head_kv = hparams.n_head_kv(); - const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const int64_t n_embd_head_k = hparams.n_embd_head_k(); - const int64_t n_embd_head_v = hparams.n_embd_head_v(); - const int64_t n_ff = hparams.n_ff(); - const int64_t n_embd_gqa = n_embd_v_gqa; - const int64_t n_vocab = vocab.n_tokens(); - const int64_t n_token_types = vocab.n_token_types(); - const int64_t n_rot = hparams.n_rot(); - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_ctx_train = hparams.n_ctx_train; - - if (n_expert > 0 && hparams.n_expert_used == 0) { - throw std::runtime_error("model has expert layers but no expert layers are used"); - } - - auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { - const buft_list_t * buft_list_layer = tn.bid == -1 ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; - return ml.create_tensor( - hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, - tn, ne, flags); - }; - - layers.resize(n_layer); - - // TODO: move to a separate function - const auto tn = LLM_TN(arch); - - // helper: try merged gate_up_exps first, fall back to separate gate and up - auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { - layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); - if (layer.ffn_gate_up_exps == nullptr) { - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); - } - }; - - // helper: try to load merged qkv first, fall back to separate q, k, v - auto create_tensor_qkv = [&](llama_layer & layer, int bid, - int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, - int flags) { - const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - if (layer.wqkv) { - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); - } - }; - - switch (arch) { - case LLM_ARCH_LLAMA: - case LLM_ARCH_REFACT: - case LLM_ARCH_MINICPM: - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_MISTRAL3: - case LLM_ARCH_LLAMA_EMBED: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - if (n_expert == 0) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional MLP bias - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); - } - } - } - } break; - case LLM_ARCH_LLADA: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = - create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock - layer.wq = - create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); - // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false - layer.wo = - create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, - TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - - // optional MLP bias - layer.ffn_gate_b = - create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = - create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); - } - } - break; - case LLM_ARCH_LLADA_MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_LLAMA4: - { - if (n_expert == 0) { - throw std::runtime_error(arch_name() + " model cannot have zero experts"); - } - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - const bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - if (is_moe_layer) { - const int64_t n_ff_exp = hparams.n_ff_exp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert - const int64_t n_ff_shexp = n_ff_exp; - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - } else { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_DECI: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - const int64_t n_ff = hparams.n_ff(i); - const int64_t n_head = hparams.n_head(i); - const int64_t n_head_kv = hparams.n_head_kv(i); - - if (n_head_kv == 0 && n_head > 0) { - // linear attention for DeciLMCausalModel - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } - else if (n_head_kv > 0) { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - } - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - if (n_ff > 0) { - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - } - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - if (n_ff > 0) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // optional MLP bias - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MINICPM3: - { - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_GROK: - { - if (n_expert == 0) { - throw std::runtime_error(arch_name() + " model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - if (!layer.ffn_post_norm) { - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } - } break; - case LLM_ARCH_DBRX: - { - if (n_expert == 0) { - throw std::runtime_error("DBRX model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_BAICHUAN: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_FALCON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_STARCODER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - // needs to be on GPU - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_BERT: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - case LLM_ARCH_JINA_BERT_V3: - { - if (n_token_types == 0) { - throw std::runtime_error(arch_name() + " model needs to define token type count"); - } - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); - - if (arch == LLM_ARCH_BERT) { - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - } - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); - - if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - } else { - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - if (arch == LLM_ARCH_NOMIC_BERT) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - } - } - - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); - layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_MODERN_BERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for(int i = 0; i < n_layer; ++i) { - auto& layer = layers[i]; - - if ( i != 0 ) { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - } else{ - // layer 0 uses identity - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - } - - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - } - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); - - } break; - case LLM_ARCH_NEO_BERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + return it->second; +} - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); +float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; +} - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); +float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; +} - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_EUROBERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); +ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { + const uint32_t n_ctx_seq = cparams.n_ctx_seq; - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_JINA_BERT_V2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; // JinaBertLayer - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm - layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); - ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); - const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; - - GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); - layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); - layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_BLOOM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_MPT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - - // FIXME test-llama-archs crashes if q_norm is created - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - // AWQ ScaleActivation layer - layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_STABLELM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional q and k layernorms, present in StableLM 2 12B - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - - // optional FFN norm, not present in StableLM 2 12B which uses parallel residual - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); - } - } break; - case LLM_ARCH_QWEN2: - case LLM_ARCH_QWEN2VL: - case LLM_ARCH_DREAM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN2MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); - } - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - } - } break; - case LLM_ARCH_QWEN3: - case LLM_ARCH_QWEN3VL: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // output rerank head - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN3MOE: - case LLM_ARCH_QWEN3VLMOE: - case LLM_ARCH_RND1: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); - } - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_PHI2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_PHI3: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_PHIMOE: - { - const int64_t n_embd_head = n_embd / n_head; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_PLAMO: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_PLAMO2: - { - // mamba parameters - const uint32_t d_conv = hparams.ssm_d_conv; - const uint32_t d_state = hparams.ssm_d_state; - const uint32_t num_heads = hparams.ssm_dt_rank; - const uint32_t intermediate_size = hparams.ssm_d_inner; - const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); - - // attention parameters - const uint32_t qk_dim = hparams.n_embd_head_k(); - const uint32_t v_dim = hparams.n_embd_head_v(); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - bool is_mamba_layer = hparams.is_recurrent(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (is_mamba_layer) { - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); - - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); - - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); - - layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); - layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); - layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); - } else { - const int64_t num_attention_heads = hparams.n_head(i); - const int64_t q_num_heads = num_attention_heads; - const int64_t num_key_value_heads = hparams.n_head_kv(i); - const int64_t k_num_heads = num_key_value_heads; - const int64_t v_num_heads = num_key_value_heads; - const int64_t q_proj_dim = q_num_heads * qk_dim; - const int64_t k_proj_dim = k_num_heads * qk_dim; - const int64_t v_proj_dim = v_num_heads * v_dim; - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); - } - - // All layers have post-attention norm, FFN norm, and FFN tensors - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); - } - } break; - case LLM_ARCH_PLAMO3: - { - const int64_t head_dim_q = hparams.n_embd_head_k(); - const int64_t head_dim_v = hparams.n_embd_head_v(); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const int64_t num_attention_heads = hparams.n_head(i); - const int64_t num_key_value_heads = hparams.n_head_kv(i); - const int64_t q_proj_dim = num_attention_heads * head_dim_q; - const int64_t k_proj_dim = num_key_value_heads * head_dim_q; - const int64_t v_proj_dim = num_key_value_heads * head_dim_v; - const int64_t n_ff_cur = hparams.n_ff(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), - {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); - } - } break; - case LLM_ARCH_GPT2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_CODESHELL: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if tok embd is NULL, init from output - if (tok_embd == NULL) { - tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_ORION: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_INTERNLM2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_GEMMA: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA3: - case LLM_ARCH_GEMMA_EMBEDDING: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // Dense linear weights - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); - dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); - - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA3N: - { - const int64_t n_altup = hparams.n_altup; - const int64_t laurel_rank = hparams.laurel_rank; - const int64_t n_embd_altup = hparams.n_embd_altup; - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - - per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - // altup & laurel - layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); - layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); - layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); - layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); - layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); - layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); - layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); - layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); - layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); - layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); - layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA4: - { - const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; - const int64_t n_ff_exp = hparams.n_ff_exp; - - if (n_embd_head_k != n_embd_head_v) { - throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); - } - if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { - throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); - } - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - if (n_embd_per_layer > 0) { - per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); - } - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - int rope_freqs_flag = 0; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const int64_t n_head = hparams.n_head(i); - const int64_t n_embd_head = hparams.n_embd_head_k(i); - const int64_t n_embd_k = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v = hparams.n_embd_v_gqa(i); - const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); - - if (!hparams.is_swa(i)) { - // full_attention layers use rope_freqs for proportional rope - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); - rope_freqs_flag = TENSOR_DUPLICATED; - } - - // handle use_double_wide_mlp - int64_t n_ff_cur = hparams.n_ff(i); - - // for expert layers, we use normal FFN as shared expert (same as python code) - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - // MoE router - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - bool has_expert = layer.ffn_gate_inp != nullptr; - - // norm - if (has_expert) { - layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); - - layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); - layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); - layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); - - // MoE FFN - layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - - // per-expert scale will be loaded as down_exps_s at the end of the current switch case - } - - // per-layer embeddings - if (n_embd_per_layer > 0) { - layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); - layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); - layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); - } - } - } break; - case LLM_ARCH_STARCODER2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional bias tensors - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); - } - } break; - case LLM_ARCH_MAMBA: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - - // only an expansion factor of 2 is supported for now - if (2 * n_embd != d_inner) { - throw std::runtime_error("only an expansion factor of 2 is supported for now"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); - - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } - } break; - case LLM_ARCH_MAMBA2: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } - } break; - case LLM_ARCH_JAMBA: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - const int64_t n_head_kv = hparams.n_head_kv(i); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); - - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (n_head_kv == 0) { - // Mamba layer - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); - - layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); - - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); - - layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); - layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else { - // Attention layers - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - - if (layer.ffn_gate_inp) { - // MoE - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } else { - // FFN (no MoE) - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - // mamba2 Mixer SSM params - // NOTE: int64_t for tensor dimensions - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_ssm_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.is_recurrent(i)) { - // ssm layers - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else { - // attention layers (with optional bias) - const int64_t n_head_i = hparams.n_head(i); - const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } - - // feed forward (w/ optional biases) - if (n_expert > 0) { - // MoE FFN - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); - } - } else { - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_XVERSE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_COMMAND_R: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (n_layer >= 64){ - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); - } - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_COHERE2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, - TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - } - } - break; - case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_OLMO2: - { - const int64_t n_embd_head = n_embd / n_head; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_SEED_OSS: - { - const uint32_t head_dim = hparams.n_embd_head_k(); - const int64_t n_qo_dim = n_head * head_dim; - const int64_t n_kv_dim = n_head_kv * head_dim; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); - - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - } - } break; - - case LLM_ARCH_OLMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_OPENELM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - const int64_t n_head = hparams.n_head(i); - const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; - const int64_t n_ff = hparams.n_ff(i); - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_GPTNEOX: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_ARCTIC: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_DEEPSEEK: - { - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_DEEPSEEK2: - case LLM_ARCH_MISTRAL4: - { - const bool is_mla = hparams.is_mla(); - - // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); - const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; - GGML_ASSERT(n_embd_head_qk_nope >= 1); - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (q_lora_rank > 0) { - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - } - - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - if (q_lora_rank > 0) { - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); - } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); - } - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); - - // note: only old legacy GGUF files will have the unsplit wkv_b tensor in - if (is_mla) { - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); - } else { - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_DEEPSEEK2OCR: - { - // similar to deepseek2, but without MLA - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // norm - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_PLM: - { - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); - const int64_t kv_lora_rank = hparams.n_lora_kv; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_BITNET: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_T5: - { - const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // n_layer: number of encoder_layers - // dec_n_layer: number of decoder_layers - const int dec_n_layer = hparams.dec_n_layer; - if (dec_n_layer > n_layer) { - layers.resize(dec_n_layer); - } - - // load encoder layers - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // load decoder layers - for (int i = 0; i < dec_n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); - // this tensor seems to be unused in HF transformers implementation - layer.attn_rel_b_cross = create_tensor( - tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - - layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_T5ENCODER: - { - const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_JAIS: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_JAIS2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // attention biases - all have shape n_embd (output dimension of projections) - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - // Jais-2 uses simple MLP (no gate) with biases - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_CHATGLM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_GLM4: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); - - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - - // Optional tensors - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_GLM4_MOE: - { - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_expert_shared = hparams.n_expert_shared; - - GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); - GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - // Load ALL tensors including NextN layer to satisfy total tensor count - // but only PROCESS up to last layer (skipping final NextN layer) in forward pass - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); - - // GLM-style attention with bias terms - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); - - // K/Q norm tensors (optional for GLM-4.5 355B variant) - layer.attn_q_norm = create_tensor( - tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); - layer.attn_k_norm = create_tensor( - tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); - - // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead - // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE - const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); - - if (use_moe) { - // MoE layers - layer.ffn_gate_inp = - create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor( - tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); - layer.ffn_down_exps = create_tensor( - tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); - layer.ffn_up_exps = create_tensor( - tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); - - // Shared expert - if (n_expert_shared > 0) { - const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; - layer.ffn_gate_shexp = create_tensor( - tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); - layer.ffn_down_shexp = create_tensor( - tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); - layer.ffn_up_shexp = create_tensor( - tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); - } - } else { - // Dense layers (first k layers) - GLM uses separate gate/up projections - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - - // Optional tensors - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); - } - } - } - break; - case LLM_ARCH_GLM_DSA: - { - const bool is_mla = hparams.is_mla(); - if (!is_mla) { - throw std::runtime_error("GLM_DSA architecture requires MLA"); - } - - // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); - const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later - flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); - - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); - - // note: only old legacy GGUF files will have the unsplit wkv_b tensor in - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - - // DSA indexer - layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); - layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); - layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); - layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); - layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - - // Optional tensors - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_NEMOTRON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional MLP bias - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - // mamba2 Mixer SSM params - // NOTE: int64_t for tensor dimensions - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_ssm_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd; - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // all blocks use the attn norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.is_recurrent(i)) { - // ssm layers - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else if (hparams.n_ff(i) == 0) { - // attention layers (with optional bias) - const int64_t n_head_i = hparams.n_head(i); - const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } else { - if (n_expert != 0) { - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); - - // MoE branch - layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - - } else { - // mlp layers - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); - } - } - } - } break; - case LLM_ARCH_EXAONE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_EXAONE4: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_EXAONE_MOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp; - const int64_t head_dim = hparams.n_embd_head_k(); - const int64_t n_qo_dim = n_head * head_dim; - const int64_t n_kv_dim = n_head_kv * head_dim; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - - // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end - if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); - - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_RWKV6: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int time_mix_extra_dim = hparams.time_mix_extra_dim; - const int time_decay_extra_dim = hparams.time_decay_extra_dim; - const int head_size = hparams.wkv_head_size; - const int attn_hidden_size = n_embd; - const int ffn_size = hparams.n_ff_arr[0]; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); - - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); - - layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); - GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); - - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); - layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); - layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); - layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); - layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); - - layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); - layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); - layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); - } - - } break; - case LLM_ARCH_RWKV6QWEN2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int time_mix_extra_dim = hparams.time_mix_extra_dim; - const int time_decay_extra_dim = hparams.time_decay_extra_dim; - const int head_size = hparams.wkv_head_size; - const int attn_hidden_size = n_embd; - const int n_head_kv = hparams.n_head_kv(); - int attn_key_value_size; - if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { - attn_key_value_size = attn_hidden_size; - } else { - attn_key_value_size = n_head_kv * head_size; - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); - - layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); - layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); - layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); - // optional bias tensors - layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); - - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_RWKV7: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int n_lora_decay = hparams.n_lora_decay; - const int n_lora_iclr = hparams.n_lora_iclr; - const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; - const int n_lora_gate = hparams.n_lora_gate; - const int attn_hidden_size = n_embd; - const int ffn_size = hparams.n_ff_arr[0]; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); - - layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); - - layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); - layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); - - if (i == 0) { - // actually not used - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); - } else { - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); - } - - layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); - layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); - - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); - - layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); - - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); - - layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); - layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); - } - - } break; - case LLM_ARCH_ARWKV7: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int n_lora_decay = hparams.n_lora_decay; - const int n_lora_iclr = hparams.n_lora_iclr; - const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; - const int n_lora_gate = hparams.n_lora_gate; - const int attn_hidden_size = n_embd; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); - - layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); - layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); - - if (i == 0) { - // actually not used - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); - } else { - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); - } - - layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); - layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); - - try { - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); - } catch(std::runtime_error & e) { - // ARWKV models may not have gate tensors - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - } - - layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); - - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - } break; - case LLM_ARCH_CHAMELEON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_WAVTOKENIZER_DEC: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); - conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); - - // posnet - { - const int64_t n_embd = hparams.posnet.n_embd; - - for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { - auto & layer = layers[i].posnet; - - // posnet: - // - // - resnet - // - resnet - // - attn - // - resnet - // - resnet - // - norm - // - switch (i) { - case 0: - case 1: - case 3: - case 4: - { - layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); - layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); - - layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); - layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); - - layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); - layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); - - layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); - layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); - } break; - case 2: - { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); - - layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); - - layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); - - layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); - - layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); - } break; - case 5: - { - layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); - layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); - } break; - default: GGML_ABORT("unknown posnet layer"); - }; - } - } - - GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); - - // convnext - { - const int64_t n_embd = hparams.convnext.n_embd; - - for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { - auto & layer = layers[i].convnext; - - layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); - layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); - - layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); - layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); - - layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); - layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); - - layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); - layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); - - layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); - } - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - } - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); - } break; - case LLM_ARCH_BAILINGMOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } break; - case LLM_ARCH_BAILINGMOE2: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - - if (static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers - const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - } else { // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags); - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); - } - } - } break; - case LLM_ARCH_DOTS1: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_ARCEE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_AFMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // dual attention normalization - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - // attention projections - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // Q/K normalization - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - // attention gating - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - - // dual ffn normalization - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - if (static_cast(i) >= hparams.n_layer_dense_lead) { - // MoE layers - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - - // grouped expert weights - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // shared expert - if (n_expert_shared > 0) { - const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - } - } else { - // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_ERNIE4_5_MOE: - case LLM_ARCH_PADDLEOCR: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers - int n_ff_exp = hparams.n_ff_exp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert (if present) - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); - } - } else { // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_FALCON_H1: - { - // Common - const int64_t hidden_size = hparams.n_embd; // hidden_size - - // mamba2 Mixer SSM params - const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size - const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups - const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size - const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand - const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads - const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; - const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; - - // attn params - const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head - const int64_t attn_num_key_value_head = hparams.n_head_kv(0); - - // ffn params - const int64_t ffn_intermediate_size = hparams.n_ff(0); - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0); - - // output - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - /*SSM LAYERS*/ - // ssm in - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); - // ssm 1d conv - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED); - // ssm_dt - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); - // ssm_norm - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED); - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); - - /*ATTENTION LAYERS*/ - // attention layers (with optional bias) - create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); - - - // feed forward (w/ optional biases) - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); - - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - } - } break; - case LLM_ARCH_HUNYUAN_VL: - case LLM_ARCH_HUNYUAN_DENSE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - } - } break; - case LLM_ARCH_SMOLLM3: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_OPENAI_MOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); - - layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); - layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); - layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); - layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_LFM2: - case LLM_ARCH_LFM2MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); - - // ffn/moe is same for transformer and conv layers - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - if (is_moe_layer) { - GGML_ASSERT(n_expert && n_expert_used); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } else { // dense - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // for operator_norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (!hparams.is_recurrent(i)) { - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); - - create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } else { - layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); - layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); - layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); - } - } - - // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); - dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); - } break; - case LLM_ARCH_SMALLTHINKER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER"); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - } - } break; - case LLM_ARCH_GROVEMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE"); - GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE"); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k; - const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); - layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0); - layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); - } - } break; - case LLM_ARCH_APERTUS: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - - // Q and K layernorms for Apertus - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MINIMAX_M2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } - } break; - case LLM_ARCH_KIMI_LINEAR: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - // Check for KDA specific tensors to determine layer type or if it's a mixed model - // Assuming KDA layer if KDA tensors are present - - // KDA uses head_dim = 128 (from linear_attn_config.head_dim) - const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; - const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; - const int64_t ssm_d_conv = hparams.ssm_d_conv; - - if (hparams.is_recurrent(i)) { - // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) - // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] - layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_q_conv) { - layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); - } - - // KDA Layer - Conv1d weights may be 3D or 4D - layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_k_conv) { - layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); - } - layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_v_conv) { - layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); - } - - // q, k, v projections - // Python: q_proj, k_proj, v_proj - create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); - - // KDA specific projections - // f_a_proj, f_b_proj - layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim - layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size - - // b_proj (beta mixing coefficient) - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); - - // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_a) { - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); - } - - // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); - - // g_a_proj, g_b_proj (output gate) - layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); - layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); - - // o_norm (reusing SSM_NORM) - layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated - - // o_proj - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); - - } else { - // MLA Layer - use MLA-specific head dimensions - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); - const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - if (layer.attn_q_a_norm) { - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); - } else { - // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); - } - - // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) - // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 - const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); - // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), - {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - if (!layer.wkv_b) { // MLA KV cache enabled - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); - } - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); - } - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - // MoE intermediate size (different from dense FFN) - const int64_t n_ff_exp = hparams.n_ff_exp; - - // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE - // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE - if (i < (int) hparams.n_layer_dense_lead) { - // Dense FFN layer - use normal n_ff - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared experts use moe_intermediate_size * num_shared_experts - // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 - // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] - const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); - - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } - } - } break; - case LLM_ARCH_COGVLM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.visexp_attn_wqkv = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); - layer.visexp_attn_wo = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - layer.visexp_ffn_gate = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.visexp_ffn_down = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_PANGU_EMBED: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - // weight tensors - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN3NEXT: - { - if (n_expert == 0) { - throw std::runtime_error(arch_name() + " model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; - - // Calculate projection sizes - const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; - const int64_t ba_dim = n_v_heads * 2; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - if (!hparams.is_recurrent(i)) { - // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - // note: ssm_in is used by legacy GGUF - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared experts - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - } - } break; - case LLM_ARCH_QWEN35MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - if (!hparams.is_recurrent(i)) { - // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared experts - const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - } - } break; - case LLM_ARCH_QWEN35: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - if (!hparams.is_recurrent(i)) { - // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_MIMO2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - uint32_t n_head = hparams.n_head(i); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - // non-MoE branch - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - // MoE branch - int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_STEP35: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor - // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. - uint32_t n_rot_max = 0; - for (int i = 0; i < n_layer; ++i) { - n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); - } - if (n_rot_max == 0) { - n_rot_max = n_rot; - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const uint32_t n_head_l = hparams.n_head(i); - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); - - // optional rope factors (llama3) / longrope tensors - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); - - // head-wise attention gate (Step35 self_attn.g_proj) - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - // dense MLP (leading dense blocks) - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - // MoE routed experts + selection bias (router_bias) - const int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - // shared expert MLP - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MAINCODER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - default: - throw std::runtime_error("unknown architecture"); - } - - // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) - // this avoids having to add scale loading to every architecture - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // attention weight scales (per-tensor, shape {1}) - if (!layer.wq_s && layer.wq) { - layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wk_s && layer.wk) { - layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wv_s && layer.wv) { - layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wo_s && layer.wo) { - layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_s && layer.wqkv) { - layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_gate_s && layer.wqkv_gate) { - layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - - // dense FFN weight scales (per-tensor, shape {1}) - if (!layer.ffn_gate_s && layer.ffn_gate) { - layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_s && layer.ffn_down) { - layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_s && layer.ffn_up) { - layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { - layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { - layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { - layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - - // MoE expert weight scales (per-expert, shape {n_expert}) - if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { - layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_exps_s && layer.ffn_down_exps) { - layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { - layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - - // recurrent / linear-attention weight scales (per-tensor, shape {1}) - if (!layer.ssm_in_s && layer.ssm_in) { - layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_out_s && layer.ssm_out) { - layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_alpha_s && layer.ssm_alpha) { - layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_beta_s && layer.ssm_beta) { - layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - - // input scales - if (!layer.wq_in_s && layer.wq) { - layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wk_in_s && layer.wk) { - layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wv_in_s && layer.wv) { - layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wo_in_s && layer.wo) { - layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_in_s && layer.wqkv) { - layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { - layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_in_s && layer.ffn_gate) { - layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_in_s && layer.ffn_down) { - layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_in_s && layer.ffn_up) { - layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { - layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { - layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { - layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { - layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { - layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { - layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_in_in_s && layer.ssm_in) { - layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_out_in_s && layer.ssm_out) { - layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { - layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_beta_in_s && layer.ssm_beta) { - layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - } - } - - ml.done_getting_tensors(); - - // populate tensors_by_name - for (auto & [_, ctx_ptr] : ml.ctx_map) { - for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { - tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } - } - - ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); - pimpl->mappings.reserve(ml.mappings.size()); - - // create the backend buffers - std::vector> ctx_buf_maps; - ctx_buf_maps.reserve(ml.ctx_map.size()); - - // Ensure we have enough capacity for the maximum backend buffer we will potentially create - const size_t n_max_backend_buffer = ml.ctx_map.size() * ml.files.size(); - pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - - for (auto & [buft, ctx_ptr] : ml.ctx_map) { - ggml_context * ctx = ctx_ptr.get(); - - // skip contexts without tensors - if (ggml_get_first_tensor(ctx) == nullptr) { - continue; - } - - llama_buf_map buf_map; - buf_map.reserve(n_max_backend_buffer); - - // check if it is possible to use buffer_from_host_ptr with this buffer type - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - // FIXME: workaround for CPU backend buft having a NULL device - dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (!dev) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - } - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; - bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - - std::vector bufs; - if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { - GGML_ASSERT(!ml.no_alloc); - for (uint32_t idx = 0; idx < ml.files.size(); idx++) { - // only the mmap region containing the tensors in the model is mapped to the backend buffer - // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, - // then we could just use metal for all layers - // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size - void * addr = nullptr; - size_t first, last; // NOLINT - ml.get_mapping_range(&first, &last, &addr, idx, ctx); - if (first >= last) { - continue; - } - const size_t max_size = ggml_get_max_tensor_size(ctx); - ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); - if (buf == nullptr) { - throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); - } - bufs.emplace_back(buf); - buf_map.emplace(idx, buf); - } - } else { - ggml_backend_buffer_t buf; - if (ml.no_alloc) { - buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { - t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them - } - } else { - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer - } - if (buf == nullptr) { - throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); - } - if (use_mlock && ggml_backend_buffer_is_host(buf)) { - pimpl->mlock_bufs.emplace_back(new llama_mlock); - auto & mlock_buf = pimpl->mlock_bufs.back(); - mlock_buf->init (ggml_backend_buffer_get_base(buf)); - mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); - } - bufs.emplace_back(buf); - for (uint32_t idx = 0; idx < ml.files.size(); idx++) { - buf_map.emplace(idx, buf); - } - } - - for (auto & buf : bufs) { - // indicate that this buffer contains weights - // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight - ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - } - - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - - ctx_buf_maps.emplace_back(ctx, buf_map); - } - - if (llama_supports_gpu_offload()) { - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - - int n_repeating = n_gpu; - if (n_repeating > 0) { - LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); - n_repeating--; - } - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; - - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - } - - // print memory requirements per buffer type - for (auto & [_, bufs] : pimpl->ctxs_bufs) { - for (auto & buf: bufs) { - LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", - __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); - } - } - - if (ml.no_alloc) { - return true; - } - - // load tensor data - for (auto & [ctx, buf_map] : ctx_buf_maps) { - if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { - return false; - } - } - - if (use_mmap_buffer) { - for (auto & mapping : ml.mappings) { - pimpl->mappings.emplace_back(std::move(mapping)); - } - } - - return true; -} - -std::string llama_model::arch_name() const { - return llm_arch_name(arch); -} - -std::string llama_model::type_name() const { - return llm_type_name(type); -} - -std::string llama_model::desc() const { - return pimpl->desc_str; -} - -size_t llama_model::size() const { - return pimpl->n_bytes; -} - -size_t llama_model::n_tensors() const { - return tensors_by_name.size(); -} - -size_t llama_model::n_devices() const { - return devices.size(); -} - -const float * llama_model::tensor_split() const { - return params.tensor_split; -} - -uint32_t llama_model::n_gpu_layers() const { - return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; -} - -llama_split_mode llama_model::split_mode() const { - return params.split_mode; -} - -std::map llama_model::memory_breakdown() const { - std::map ret; - for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { - if (hparams.no_alloc) { - GGML_ASSERT(bufs.size() == 1); - ggml_backend_buffer_t buf = bufs[0].get(); - GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr); - ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); - ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); - } else { - for (const auto & buf : bufs) { - // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base - ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); - } - } - } - return ret; -} - -uint64_t llama_model::n_elements() const { - return pimpl->n_elements; -} - -void llama_model::print_info() const { - const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - - auto print_f = [](const std::function & f, uint32_t n) { - bool is_var = false; - - std::vector v; - for (uint32_t i = 0; i < n; ++i) { - v.push_back(f(i)); - if (v[i] != v[0]) { - is_var = true; - } - } - - std::stringstream ss; - - if (is_var) { - ss << "["; - for (uint32_t i = 0; i < n; ++i) { - ss << v[i]; - if (i < n - 1) { - ss << ", "; - } - } - ss << "]"; - } else { - ss << v[0]; - } - - return ss.str(); - }; - - // hparams - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); - LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); - - if (!hparams.vocab_only) { - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); - LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); - LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); - LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); - LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); - LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); - LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); - LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); - LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa); - LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa); - LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa); - } - LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); - LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); - // MRoPE (Multi-axis Rotary Position Embedding) sections - if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { - LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); - } - if (!classifier_labels.empty()) { - LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); - - size_t i = 0; - for (const auto & label : classifier_labels) { - LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); - } - } - - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_QWEN35 || - arch == LLM_ARCH_QWEN35MOE || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } - - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } - - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } - - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } - - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } - - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } - - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } - - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } - - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } - - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } - - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); - } - } - - vocab.print_info(); -} - -ggml_backend_dev_t llama_model::dev_layer(int il) const { - return pimpl->dev_layer.at(il).dev; -} - -ggml_backend_dev_t llama_model::dev_output() const { - return pimpl->dev_output.dev; -} - -template -static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx { ggml_init(params) }; - if (!ctx) { - throw std::runtime_error(format("failed to create ggml context")); - } - - ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; - ggml_tensor * op_tensor = fn(ctx.get()); - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op_tensor->src[i] != nullptr) { - assert(op_tensor->src[i]->buffer == nullptr); - op_tensor->src[i]->buffer = buf.get(); - } - } - - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - - return op_supported; -} - -template -static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (buft_supported(cur_buft, cur_dev, fn)) { - return cur_buft; - } - } - - throw std::runtime_error(format("no suitable buffer type found")); -} - -ggml_backend_buffer_type_t llama_model::select_buft(int il) const { - return ::select_buft( - *pimpl->dev_layer.at(il).buft_list, - [&](ggml_context * ctx) { - ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - return ggml_add(ctx, cur, layer_dir); - }); -} - -bool llama_model::has_tensor_overrides() const { - return pimpl->has_tensor_overrides; -} - -const ggml_tensor * llama_model::get_tensor(const char * name) const { - auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), - [name](const std::pair & it) { - return it.first == name; - }); - if (it == tensors_by_name.end()) { - return nullptr; - } - - return it->second; -} - -float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { - return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; -} - -float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { - return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; -} - -ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { - const uint32_t n_ctx_seq = cparams.n_ctx_seq; - - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } - - if (n_ctx_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } - - return layers[il].rope_short; -} - -llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { - llama_memory_i * res; - - switch (arch) { - // Models that need specific instantiation should be handled in the - // switch statement - case LLM_ARCH_BERT: - case LLM_ARCH_JINA_BERT_V2: - case LLM_ARCH_JINA_BERT_V3: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - case LLM_ARCH_NEO_BERT: - case LLM_ARCH_EUROBERT: - case LLM_ARCH_WAVTOKENIZER_DEC: - case LLM_ARCH_MODERN_BERT: - case LLM_ARCH_GEMMA_EMBEDDING: - case LLM_ARCH_DREAM: - case LLM_ARCH_LLADA: - case LLM_ARCH_LLADA_MOE: - case LLM_ARCH_RND1: - { - res = nullptr; - } break; - // Models that need standard caching should rely on recurrent/hybrid - // checks - default: - { - if (llm_arch_is_recurrent(arch)) { - res = new llama_memory_recurrent( - *this, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max, - nullptr); - } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the - // layer filters, so pick the right one here - llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; - llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; - if (arch == LLM_ARCH_FALCON_H1) { - filter_attn = [&](int32_t) { return true; }; - filter_recr = [&](int32_t) { return true; }; - } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - filter_attn = [&](int32_t il) { - return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; - }; - filter_recr = [&](int32_t il) { - return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; - }; - } - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - // Use hybrid-iswa for hybrid models with SWA - res = new llama_memory_hybrid_iswa( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_swa_full */ params.swa_full, - /* attn_kv_size */ cparams.n_ctx_seq, - /* attn_n_ubatch */ cparams.n_ubatch, - /* attn_n_pad */ 1, - /* recurrent_type_r */ GGML_TYPE_F32, - /* recurrent_type_s */ GGML_TYPE_F32, - /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); - } else { - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx_seq, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); - } - } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; - - if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { - reuse = [&](int32_t il) { - if (il >= (int32_t) hparams.n_layer_kv_from_start) { - return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); - } - - return -1; - }; - } - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); - - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - nullptr, - reuse); - } else { - GGML_ASSERT(!hparams.is_swa_any()); - - res = new llama_kv_cache( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - 1, - hparams.n_swa, - hparams.swa_type, - nullptr, - nullptr); - } - } - } - } - - return res; -} - -ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { - std::unique_ptr llm; - - switch (arch) { - case LLM_ARCH_LLAMA: - { - llm = std::make_unique>(*this, params); - } break; - case LLM_ARCH_LLAMA4: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_LLAMA_EMBED: - { - llm = std::make_unique>(*this, params); - } break; - case LLM_ARCH_MAINCODER: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DECI: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BAICHUAN: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_FALCON: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GROK: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STARCODER: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_REFACT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BERT: - case LLM_ARCH_JINA_BERT_V2: - case LLM_ARCH_JINA_BERT_V3: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MODERN_BERT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_NEO_BERT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_EUROBERT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BLOOM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MPT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STABLELM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DREAM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_LLADA: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_LLADA_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RND1: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN2VL: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN2MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3VL: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3VLMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PHI2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PHI3: - case LLM_ARCH_PHIMOE: - { - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique> (*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_PLAMO: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PLAMO2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PLAMO3: - { - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique> (*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_GPT2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_CODESHELL: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ORION: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_INTERNLM2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MINICPM3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA3: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_GEMMA3N: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA4: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA_EMBEDDING: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STARCODER2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_MAMBA2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_JAMBA: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_XVERSE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_COMMAND_R: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_COHERE2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DBRX: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OLMO: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OLMO2: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_OLMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OPENELM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GPTNEOX: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ARCTIC: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DEEPSEEK: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DEEPSEEK2: - case LLM_ARCH_DEEPSEEK2OCR: - case LLM_ARCH_GLM_DSA: - case LLM_ARCH_MISTRAL4: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_CHATGLM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GLM4: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GLM4_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BITNET: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_T5: - { - switch (params.gtype) { - case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique>(*this, params); - break; - case LLM_GRAPH_TYPE_DEFAULT: - case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique>(*this, params); - break; - default: - GGML_ABORT("invalid graph type"); - }; - } break; - case LLM_ARCH_T5ENCODER: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_JAIS: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_JAIS2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_NEMOTRON: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_EXAONE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_EXAONE4: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_EXAONE_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RWKV6: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RWKV6QWEN2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RWKV7: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ARWKV7: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_MINICPM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_CHAMELEON: - { - llm = std::make_unique(*this, params); - } break; + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { + llama_memory_i * res; + + switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_NEO_BERT: + case LLM_ARCH_EUROBERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_MODERN_BERT: + case LLM_ARCH_GEMMA_EMBEDDING: + case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: + case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PLM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BAILINGMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BAILINGMOE2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_SEED_OSS: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DOTS1: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ARCEE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_AFMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ERNIE4_5: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ERNIE4_5_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PADDLEOCR: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_HUNYUAN_VL: - case LLM_ARCH_HUNYUAN_DENSE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_SMOLLM3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OPENAI_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_FALCON_H1: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_LFM2: - case LLM_ARCH_LFM2MOE: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } + res = nullptr; } break; - case LLM_ARCH_SMALLTHINKER: + // Models that need standard caching should rely on recurrent/hybrid + // checks + default: { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique> (*this, params); + if (llm_arch_is_recurrent(arch)) { + res = new llama_memory_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max, + nullptr); + } else if (llm_arch_is_hybrid(arch)) { + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here + llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](int32_t) { return true; }; + filter_recr = [&](int32_t) { return true; }; + } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { + filter_attn = [&](int32_t il) { + return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + filter_recr = [&](int32_t il) { + return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + } + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { - llm = std::make_unique>(*this, params); + llama_memory_i::layer_reuse_cb reuse = nullptr; + + if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { + reuse = [&](int32_t il) { + if (il >= (int32_t) hparams.n_layer_kv_from_start) { + return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + } + + return -1; + }; + } + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + nullptr, + reuse); + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + nullptr); + } } - } break; - case LLM_ARCH_GROVEMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_APERTUS: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MINIMAX_M2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_COGVLM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PANGU_EMBED: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3NEXT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN35: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN35MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MISTRAL3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MIMO2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_KIMI_LINEAR: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STEP35: - { - llm = std::make_unique(*this, params); - } break; - default: - GGML_ABORT("fatal error"); + } } + return res; +} + +ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { + std::unique_ptr llm = build_arch_graph(params); + // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); @@ -9487,3 +2472,43 @@ ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int } return model->devices[i].dev; } + +// +// llama_model_base +// + +llama_model_base::llama_model_base(const struct llama_model_params & params) : llama_model(params), model(this), tn(model->arch), + TENSOR_DUPLICATED (llama_model_loader::TENSOR_DUPLICATED), + TENSOR_NOT_REQUIRED (llama_model_loader::TENSOR_NOT_REQUIRED), + TENSOR_SKIP (llama_model_loader::TENSOR_SKIP), + TENSOR_SKIP_IF_VIRTUAL(llama_model_loader::TENSOR_SKIP_IF_VIRTUAL) {} + +ggml_tensor * llama_model_base::create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) { + GGML_ASSERT(ml != nullptr); + return create_tensor(*ml, tn, ne, flags); +} + +void llama_model_base::create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + } +} + +void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags) { + const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (layer.wqkv) { + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 5f101bd6374..d63c689185a 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -577,14 +577,8 @@ struct llama_model { int64_t t_load_us = 0; int64_t t_start_us = 0; - explicit llama_model(const struct llama_model_params & params); - ~llama_model(); - - void load_stats (llama_model_loader & ml); - void load_arch (llama_model_loader & ml); - void load_hparams(llama_model_loader & ml); - void load_vocab (llama_model_loader & ml); - bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback + explicit llama_model(const llama_model_params & params); + virtual ~llama_model(); std::string arch_name() const; std::string type_name() const; @@ -620,21 +614,94 @@ struct llama_model { ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const; - // TODO: move this to new llm_arch_model_i interface llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const; - // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; -private: + virtual void load_stats (llama_model_loader & ml) = 0; + virtual void load_hparams(llama_model_loader & ml) = 0; + virtual void load_vocab (llama_model_loader & ml) = 0; + virtual bool load_tensors(llama_model_loader & ml) = 0; // returns false if cancelled by progress_callback + + // model must define these + virtual void load_arch_hparams(llama_model_loader & ml) = 0; + virtual void load_arch_tensors(llama_model_loader & ml) = 0; + virtual std::unique_ptr build_arch_graph(const llm_graph_params & params) const = 0; + +protected: llama_model_params params; struct impl; std::unique_ptr pimpl; }; +llama_model * llama_model_create(llm_arch arch, const llama_model_params & params); +llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params); + +// model must inherit from this +struct llama_model_base : public llama_model { + friend struct llama_model; + + llama_model * model; + llama_model_loader * ml = nullptr; + const LLM_TN tn; + + // llama_model_loader is not yet defined at this point, so we will set it after construction + const int TENSOR_DUPLICATED; + const int TENSOR_NOT_REQUIRED; + const int TENSOR_SKIP; + const int TENSOR_SKIP_IF_VIRTUAL; + + explicit llama_model_base(const llama_model_params & params); + virtual ~llama_model_base() = default; + + ggml_tensor * create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags); + + // convenience overload of create_tensor that doesn't require llama_model_loader + ggml_tensor * create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags); + + // helper: try merged gate_up_exps first, fall back to separate gate and up + void create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_, + int64_t n_ff_, int64_t n_expert_, int flags); + + // helper: try to load merged qkv first, fall back to separate q, k, v + void create_tensor_qkv(llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags); + + void load_stats (llama_model_loader & ml) override; + void load_hparams(llama_model_loader & ml) override; + void load_vocab (llama_model_loader & ml) override; + bool load_tensors(llama_model_loader & ml) override; + + // model must define these + void load_arch_hparams(llama_model_loader & ml) override = 0; + void load_arch_tensors(llama_model_loader & ml) override = 0; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override = 0; +}; + const char * llm_type_name(llm_type type); +// convenience macro for loading local variables for load_tensors() in llama_model_base +// note: cast to int64_t since we will use these for the tensor dimensions +#define LLAMA_LOAD_LOCALS \ + const int n_layer = hparams.n_layer; GGML_UNUSED(n_layer); \ + const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ + const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ + const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); GGML_UNUSED(n_embd_k_gqa); \ + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_UNUSED(n_embd_v_gqa); \ + const int64_t n_embd_head_k = hparams.n_embd_head_k(); GGML_UNUSED(n_embd_head_k); \ + const int64_t n_embd_head_v = hparams.n_embd_head_v(); GGML_UNUSED(n_embd_head_v); \ + const int64_t n_ff = hparams.n_ff(); GGML_UNUSED(n_ff); \ + const int64_t n_embd_gqa = n_embd_v_gqa; GGML_UNUSED(n_embd_gqa); \ + const int64_t n_vocab = vocab.n_tokens(); GGML_UNUSED(n_vocab); \ + const int64_t n_token_types = vocab.n_token_types(); GGML_UNUSED(n_token_types); \ + const int64_t n_rot = hparams.n_rot(); GGML_UNUSED(n_rot); \ + const int64_t n_expert = hparams.n_expert; GGML_UNUSED(n_expert); \ + const int64_t n_expert_used = hparams.n_expert_used; GGML_UNUSED(n_expert_used); \ + const int64_t n_ctx_train = hparams.n_ctx_train; GGML_UNUSED(n_ctx_train); + // For internal test use // TODO: remove const std::vector> & llama_internal_get_tensor_map(const llama_model * model); diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 2f0f70b73b6..43e05c3d56f 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -882,13 +882,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching - llama_model model(llama_model_default_params()); + auto mparams = llama_model_default_params(); + std::unique_ptr model_ptr(llama_model_create(ml, mparams)); - model.load_arch (ml); - model.load_hparams(ml); - model.load_stats (ml); + auto * model = dynamic_cast(model_ptr.get()); + if (model == nullptr) { + GGML_ABORT("fatal error: model does not implement llama_model_base"); + } + + model->load_hparams(ml); + model->load_stats (ml); - quantize_state_impl qs(model, params); + quantize_state_impl qs(*model, params); if (params->only_copy) { ftype = ml.ftype; @@ -1023,7 +1028,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } gguf_add_tensor(ctx_outs[i_split].get(), tensor); - metadata[i].allows_quantization = tensor_allows_quantization(params, model.arch, tensor); + metadata[i].allows_quantization = tensor_allows_quantization(params, model->arch, tensor); if (metadata[i].allows_quantization) { metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]); @@ -1331,9 +1336,9 @@ void llama_quant_free(quantize_state_impl * qs) { llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) { struct llama_model_params mparams = llama_model_default_params(); - auto * model = new llama_model(mparams); - - model->arch = llm_arch_from_string(desc->architecture); + auto arch = llm_arch_from_string(desc->architecture); + auto * model = llama_model_create(arch, mparams); + model->arch = arch; // infer llm_type: only LLM_TYPE_70B matters for quantization logic if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 163f222ef61..f43cf546ca0 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -503,6 +503,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding break; + case LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE: + // Sarvam uses SPM-style BPE (same shape as Gemma4): spaces replaced with U+2581 + // by the normalizer, BPE merges over the whole text on raw UTF-8. + regex_exprs = { + "[^\\n]+|[\\n]+", + }; + byte_encode = false; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -2005,6 +2013,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "gemma4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; escape_whitespaces = true; + } else if ( + tokenizer_pre == "sarvam-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE; + escape_whitespaces = true; + clean_spaces = false; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index dd38f45d3a2..8b040b912e2 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -59,6 +59,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, + LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index e9c3028585d..dfe30ce8f61 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -71,12 +71,18 @@ bool llama_supports_mlock(void) { } bool llama_supports_gpu_offload(void) { + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr || ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr || llama_supports_rpc(); } bool llama_supports_rpc(void) { + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } return ggml_backend_reg_by_name("RPC") != nullptr; } @@ -89,6 +95,10 @@ void llama_backend_init(void) { struct ggml_context * ctx = ggml_init(params); ggml_free(ctx); } + + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } } void llama_numa_init(enum ggml_numa_strategy numa) { @@ -111,113 +121,8 @@ int64_t llama_time_us(void) { return ggml_time_us(); } -// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback -static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, - const std::string & fname, std::vector & splits, FILE * file, llama_model & model, llama_model_params & params) { - // loading time will be recalculated after the first eval, so - // we take page faults deferred by mmap() into consideration - model.t_load_us = 0; - time_meas tm(model.t_load_us); - - model.t_start_us = tm.t_start_us; - - try { - llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, - params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); - - ml.print_info(); - - model.hparams.vocab_only = params.vocab_only; - model.hparams.no_alloc = params.no_alloc; - - try { - model.load_arch(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model architecture: " + std::string(e.what())); - } - try { - model.load_hparams(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); - } - if (model.arch == LLM_ARCH_CLIP) { - throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead"); - } - try { - model.load_vocab(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); - } - - model.load_stats(ml); - model.print_info(); - - if (params.vocab_only) { - LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); - return 0; - } - - if (!model.load_tensors(ml)) { - return -2; - } - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); - return -1; - } - - return 0; -} - -static struct llama_model * llama_model_load_from_file_impl( - struct gguf_context * metadata, - llama_model_set_tensor_data_t set_tensor_data, - void * set_tensor_data_ud, - const std::string & path_model, - std::vector & splits, - FILE * file, - struct llama_model_params params) { - { - int n_sources_defined = 0; - if (metadata != nullptr) { - n_sources_defined++; - } - if (!path_model.empty()) { - n_sources_defined++; - } - if (file != nullptr) { - n_sources_defined++; - } - if (n_sources_defined != 1) { - LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); - return nullptr; - } - } - ggml_time_init(); - - if (!params.vocab_only && ggml_backend_reg_count() == 0) { - LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__); - return nullptr; - } - - unsigned cur_percentage = 0; - if (params.progress_callback == NULL) { - params.progress_callback_user_data = &cur_percentage; - params.progress_callback = [](float progress, void * ctx) { - unsigned * cur_percentage_p = (unsigned *) ctx; - unsigned percentage = (unsigned) (100 * progress); - while (percentage > *cur_percentage_p) { - *cur_percentage_p = percentage; - LLAMA_LOG_CONT("."); - if (percentage >= 100) { - LLAMA_LOG_CONT("\n"); - } - } - return true; - }; - } - - llama_model * model = new llama_model(params); - +// returns true on success +static bool llama_prepare_model_devices(const llama_model_params & params, llama_model * model) { // create list of devices to use with this model if (params.devices) { if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { @@ -227,7 +132,7 @@ static struct llama_model * llama_model_load_from_file_impl( } if (n_devs == 0) { LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); - return nullptr; + return false; } LLAMA_LOG_INFO("%s: creating a Meta device with %zu devices\n", __func__, n_devs); for (size_t i = 0; i < n_devs; ++i) { @@ -265,7 +170,7 @@ static struct llama_model * llama_model_load_from_file_impl( } if (devs.empty()) { LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); - return nullptr; + return false; } LLAMA_LOG_INFO("%s: creating a Meta device for tensor parallelism from %zu devices:\n", __func__, devs.size()); @@ -347,8 +252,7 @@ static struct llama_model * llama_model_load_from_file_impl( } else { if (params.main_gpu >= (int)model->devices.size()) { LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size()); - llama_model_free(model); - return nullptr; + return false; } llama_device main_gpu = model->devices[params.main_gpu]; model->devices.clear(); @@ -365,7 +269,121 @@ static struct llama_model * llama_model_load_from_file_impl( props.memory_free/1024/1024); } - const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, *model, params); + return true; +} + +// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback +static std::pair llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, + const std::string & fname, std::vector & splits, FILE * file, llama_model_params & params) { + try { + llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, + params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); + + ml.print_info(); + std::unique_ptr model_ptr(llama_model_create(ml, params)); + + bool ok = llama_prepare_model_devices(params, model_ptr.get()); + if (!ok) { + return {-1, nullptr}; + } + + auto * model = dynamic_cast(model_ptr.get()); + if (model == nullptr) { + GGML_ABORT("fatal error: model does not implement llama_model_base"); + } + + // loading time will be recalculated after the first eval, so + // we take page faults deferred by mmap() into consideration + model->t_load_us = 0; + time_meas tm(model->t_load_us); + + model->t_start_us = tm.t_start_us; + + model->hparams.vocab_only = params.vocab_only; + model->hparams.no_alloc = params.no_alloc; + + try { + model->load_hparams(ml); + } catch(const std::exception & e) { + throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); + } + if (model->arch == LLM_ARCH_CLIP) { + throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead"); + } + try { + model->load_vocab(ml); + } catch(const std::exception & e) { + throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); + } + + model->load_stats(ml); + model->print_info(); + + if (params.vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); + return {0, model_ptr.release()}; + } + + if (!model->load_tensors(ml)) { + return {-2, nullptr}; + } + + return {0, model_ptr.release()}; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); + return {-1, nullptr}; + } +} + +static struct llama_model * llama_model_load_from_file_impl( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, + const std::string & path_model, + std::vector & splits, + FILE * file, + struct llama_model_params params) { + { + int n_sources_defined = 0; + if (metadata != nullptr) { + n_sources_defined++; + } + if (!path_model.empty()) { + n_sources_defined++; + } + if (file != nullptr) { + n_sources_defined++; + } + if (n_sources_defined != 1) { + LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); + return nullptr; + } + } + ggml_time_init(); + + if (!params.vocab_only && ggml_backend_reg_count() == 0) { + LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__); + return nullptr; + } + + unsigned cur_percentage = 0; + if (params.progress_callback == NULL) { + params.progress_callback_user_data = &cur_percentage; + params.progress_callback = [](float progress, void * ctx) { + unsigned * cur_percentage_p = (unsigned *) ctx; + unsigned percentage = (unsigned) (100 * progress); + while (percentage > *cur_percentage_p) { + *cur_percentage_p = percentage; + LLAMA_LOG_CONT("."); + if (percentage >= 100) { + LLAMA_LOG_CONT("\n"); + } + } + return true; + }; + } + + const auto [status, model] = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -374,7 +392,9 @@ static struct llama_model * llama_model_load_from_file_impl( LLAMA_LOG_INFO("%s: cancelled model load\n", __func__); } - llama_model_free(model); + if (model) { + llama_model_free(model); + } return nullptr; } diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index eb869814097..2ea226726ad 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -864,6 +864,9 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 +// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load) +#define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext( diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 2790b12111d..602e3176afd 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -1,6 +1,112 @@ #include "models.h" -llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_afmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + // Set up interleaved sliding window attention (ISWA) + // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + // Default to sigmoid if not set + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + switch (hparams.n_layer) { + case 56: type = LLM_TYPE_6B; break; + case 32: type = LLM_TYPE_26B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_afmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // dual attention normalization + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + // attention projections + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // Q/K normalization + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + // attention gating + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + + // dual ffn normalization + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { + // MoE layers + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + + // grouped expert weights + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + } + } else { + // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_afmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index af44cea6054..136ff702957 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -1,6 +1,62 @@ #include "models.h" -llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_apertus::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_apertus::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_apertus::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 2e71f5d9e2a..70e86d41130 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -1,6 +1,51 @@ #include "models.h" -llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_arcee::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Arcee uses the same structure as Llama + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arcee::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_arcee::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index f8ca6aff6ab..d8653a44639 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -1,6 +1,59 @@ #include "models.h" -llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_arctic::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 128) { + switch (hparams.n_layer) { + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arctic::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr llama_model_arctic::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 107a3bef8da..79aa8c90899 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -1,7 +1,123 @@ #include "models.h" +void llama_model_arwkv7::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arwkv7::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); + + try { + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + } catch(std::runtime_error & e) { + // ARWKV models may not have gate tensors + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + } + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + +} + +std::unique_ptr llama_model_arwkv7::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} -llm_build_arwkv7::llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { +llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 2d0d05df485..4e55290e4e5 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_baichuan::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_13B) { + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } +} + +void llama_model_baichuan::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_baichuan::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 67a7120d622..030dd4f42a4 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -1,6 +1,65 @@ #include "models.h" -llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_bailingmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_16B; break; + case 88: type = LLM_TYPE_290B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bailingmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } +} + +std::unique_ptr llama_model_bailingmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 497b4babd0c..e7fe3d5b45a 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -1,6 +1,100 @@ #include "models.h" -llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : +void llama_model_bailingmoe2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 20: type = LLM_TYPE_16B_A1B; break; + case 21: type = LLM_TYPE_16B_A1B; break; + case 32: type = LLM_TYPE_100B_A6B; break; + case 33: type = LLM_TYPE_100B_A6B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags); + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); + } + } +} + +std::unique_ptr llama_model_bailingmoe2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 7e046cfd2a4..3c28f419ccf 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -1,6 +1,83 @@ #include "models.h" -llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 3: + type = LLM_TYPE_17M; break; // bge-micro + case 6: + type = LLM_TYPE_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index 71526354ca6..7e8125deec4 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -1,7 +1,54 @@ #include "models.h" +void llama_model_bitnet::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bitnet::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_bitnet::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bitnet::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index f3b0999bf54..b600fb0c954 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -1,6 +1,68 @@ #include "models.h" -llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_bloom::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; +} + +void llama_model_bloom::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_bloom::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 21deaba1a6d..8510b9e29f8 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -1,8 +1,56 @@ #include "models.h" - #include -llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_chameleon::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_chameleon::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_chameleon::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 7d4a43fdca5..e898eff7939 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -1,7 +1,60 @@ #include "models.h" +void llama_model_chatglm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_chatglm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_chatglm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} -llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index 3ceb5835b85..e9e85d96713 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_codeshell::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_codeshell::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if tok embd is NULL, init from output + if (tok_embd == NULL) { + tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_codeshell::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index be3eeeddac7..79236121bd5 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : +void llama_model_cogvlm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_cogvlm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.visexp_attn_wqkv = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); + layer.visexp_attn_wo = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.visexp_ffn_gate = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.visexp_ffn_down = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_cogvlm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2.cpp similarity index 60% rename from examples/talk-llama/models/cohere2-iswa.cpp rename to examples/talk-llama/models/cohere2.cpp index 670b08e7d97..12edbae1094 100644 --- a/examples/talk-llama/models/cohere2-iswa.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -1,6 +1,53 @@ #include "models.h" -llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_cohere2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } +} + +std::unique_ptr llama_model_cohere2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index 067961caa08..decb89f547b 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -1,8 +1,48 @@ #include "models.h" +void llama_model_command_r::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_command_r::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (n_layer >= 64){ + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + } + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_command_r::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} -llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) : +llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index 0e882721807..bce6b04bcf9 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_dbrx::load_arch_hparams(llama_model_loader & ml) { +ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); +ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + +switch (hparams.n_layer) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; +} + } + +void llama_model_dbrx::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr llama_model_dbrx::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 30272eabd69..9f1a959c32c 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -1,6 +1,82 @@ #include "models.h" -llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_deci::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + case 162: type = LLM_TYPE_405B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deci::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (n_ff > 0) { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_ff > 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_deci::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index 671b72dfead..c7946059662 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -1,6 +1,77 @@ #include "models.h" -llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : +void llama_model_deepseek::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + + switch (hparams.n_ff_exp) { + case 1408: type = LLM_TYPE_16B; break; + case 1792: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_deepseek::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index 303fc72c610..1fe54adc13e 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -1,6 +1,149 @@ #include "models.h" -llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : +void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + } + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // (optional) temperature tuning - used by mistral-large + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? + + hparams.f_attn_temp_offset = 0.0f; + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const bool is_mla = hparams.is_mla(); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + GGML_ASSERT(n_embd_head_qk_nope >= 1); + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (q_lora_rank > 0) { + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + } + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (q_lora_rank > 0) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_deepseek2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B bool is_ocr = model.arch == LLM_ARCH_DEEPSEEK2OCR; diff --git a/examples/talk-llama/models/deepseek2ocr.cpp b/examples/talk-llama/models/deepseek2ocr.cpp new file mode 100644 index 00000000000..f9e4c98785c --- /dev/null +++ b/examples/talk-llama/models/deepseek2ocr.cpp @@ -0,0 +1,82 @@ +#include "models.h" + +void llama_model_deepseek2ocr::load_arch_hparams(llama_model_loader & ml) { + // similar to deepseek2, but without MLA + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek2ocr::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + // similar to deepseek2, but without MLA + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // norm + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_deepseek2ocr::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 5d1750fedda..93cbcf9d931 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -1,6 +1,76 @@ #include "models.h" -llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) : +void llama_model_dots1::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_142B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_dots1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_dots1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 8e7d9ae64c7..60a3f0ec285 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -1,6 +1,54 @@ #include "models.h" -llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) : +void llama_model_dream::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers + switch (hparams.n_layer) { + case 28: + type = LLM_TYPE_7B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_dream::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_dream::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_dream::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //copied from qwen2 const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index fc6a3e17a09..2bd01a2c512 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -1,6 +1,10 @@ #include "models.h" -llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : +std::unique_ptr llama_model_ernie4_5_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index 033ba409eab..fa989fe92cd 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -1,6 +1,79 @@ #include "models.h" -llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : +void llama_model_ernie4_5::load_arch_hparams(llama_model_loader & ml) { + // paddleocr need mrope_section + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_ERNIE4_5_MOE) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + } + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_0_3B; break; + case 28: type = LLM_TYPE_21B_A3B; break; + case 54: type = LLM_TYPE_300B_A47B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_ernie4_5::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + } + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_ernie4_5::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index 43fff4daf3a..ddf13c3028f 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -1,6 +1,41 @@ #include "models.h" -llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_eurobert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_layer == 12) { + type = LLM_TYPE_SMALL; // 0.2B + } +} + +void llama_model_eurobert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_eurobert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_eurobert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 7b88a31d39d..54bb3ca86b3 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -1,6 +1,117 @@ #include "models.h" -llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) : +void llama_model_exaone_moe::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 128; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_30B_A3B; break; + case 48: + case 49: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp; + const int64_t head_dim = hparams.n_embd_head_k(); + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end + if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_exaone_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 4f845bf4106..75d5f60631c 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) : +void llama_model_exaone::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_exaone::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 34bee3b8fe9..5506e76424d 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,7 +1,71 @@ #include "models.h" +void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { + if (hparams.n_layer == 64) { // 32B + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_1_2B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_exaone4::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : +llama_model_exaone4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); @@ -108,5 +172,5 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ } // Explicit template instantiations -template struct llm_build_exaone4; -template struct llm_build_exaone4; +template struct llama_model_exaone4::graph; +template struct llama_model_exaone4::graph; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index 05accf90fad..d353befdb8e 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -1,6 +1,115 @@ #include "models.h" -llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : +void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) { + // Common parameters + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + + switch (hparams.n_layer) { + case 36: + type = LLM_TYPE_0_5B; break; + case 24: + type = LLM_TYPE_1_5B; break; + case 66: + type = LLM_TYPE_1B; break; + case 32: + type = LLM_TYPE_3B; break; + case 44: + type = LLM_TYPE_7B; break; + case 72: + type = LLM_TYPE_34B; break; + default: + type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_falcon_h1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // Common + const int64_t hidden_size = hparams.n_embd; // hidden_size + + // mamba2 Mixer SSM params + const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size + const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups + const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size + const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand + const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; + + // attn params + const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head + const int64_t attn_num_key_value_head = hparams.n_head_kv(0); + + // ffn params + const int64_t ffn_intermediate_size = hparams.n_ff(0); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + /*SSM LAYERS*/ + // ssm in + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); + // ssm 1d conv + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED); + // ssm_dt + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); + // ssm_norm + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); + + /*ATTENTION LAYERS*/ + // attention layers (with optional bias) + create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); + + + // feed forward (w/ optional biases) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_falcon_h1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 2f65fa56e1f..75f2cfef560 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -1,6 +1,53 @@ #include "models.h" -llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_falcon::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_falcon::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_falcon::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index b6de9551c52..4e07f5f2bda 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -1,6 +1,78 @@ #include "models.h" -llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : +void llama_model_gemma_embedding::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.causal_attn = false; // embeddings do not use causal attention + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k())); + +} + +void llama_model_gemma_embedding::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma_embedding::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma_embedding::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 09d2ff8bae7..06731670007 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -1,6 +1,44 @@ #include "models.h" -llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gemma::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2.cpp similarity index 53% rename from examples/talk-llama/models/gemma2-iswa.cpp rename to examples/talk-llama/models/gemma2.cpp index 0ef07df8d01..6255bf740fc 100644 --- a/examples/talk-llama/models/gemma2-iswa.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -1,6 +1,65 @@ #include "models.h" -llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gemma2::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; // default value of gemma 2 + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.attn_soft_cap = true; + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); +} + +void llama_model_gemma2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index 0da4af21c17..ee510fe38b0 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -1,7 +1,87 @@ #include "models.h" +void llama_model_gemma3::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + hparams.f_final_logit_softcapping = 0.0f; + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_270M; break; + case 26: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; // Rnj-1 + case 34: type = LLM_TYPE_4B; break; + case 48: type = LLM_TYPE_12B; break; + case 62: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); +} + +void llama_model_gemma3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_gemma3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; @@ -141,5 +221,5 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -template struct llm_build_gemma3; -template struct llm_build_gemma3; +template struct llama_model_gemma3::graph; +template struct llama_model_gemma3::graph; diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n.cpp similarity index 76% rename from examples/talk-llama/models/gemma3n-iswa.cpp rename to examples/talk-llama/models/gemma3n.cpp index f8095417e06..881499b0ca7 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -1,5 +1,86 @@ #include "models.h" +void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) { + uint32_t swa_period = 5; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(swa_period); + + hparams.n_layer_kv_from_start = 20; + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_E2B; break; + case 35: type = LLM_TYPE_E4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma3n::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_altup = hparams.n_altup; + const int64_t laurel_rank = hparams.laurel_rank; + const int64_t n_embd_altup = hparams.n_embd_altup; + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // altup & laurel + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); + layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); + layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); + layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); + layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); + layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); + layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); + layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma3n::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { GGML_ASSERT(idx < (int) x->ne[2]); @@ -7,7 +88,7 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); } -llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : +llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), n_embd_head(model.hparams.n_embd_head_k()), @@ -229,13 +310,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { +ggml_tensor * llama_model_gemma3n::graph::calc_magnitude(ggml_tensor * x) { return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); } // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() { +ggml_tensor * llama_model_gemma3n::graph::build_inp_per_layer() { auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; float tok_embd_scale = sqrtf((float) n_embd_altup); @@ -268,7 +349,7 @@ ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() { // equivalent to project_per_layer_inputs() in python code // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim // output shape: [n_embd_altup, n_tokens, n_layer] -ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { +ggml_tensor * llama_model_gemma3n::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); @@ -291,7 +372,7 @@ ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp // input cur shape: [n_altup, n_tokens] // output shape: [n_altup, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { +ggml_tensor * llama_model_gemma3n::graph::laurel(ggml_tensor * cur, int il) { ggml_tensor * tmp = cur; tmp = build_lora_mm(model.layers[il].laurel_l, tmp); tmp = build_lora_mm(model.layers[il].laurel_r, tmp); @@ -303,7 +384,7 @@ ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { // input x shape: [n_embd, n_tokens] // output shape: [n_embd, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { +ggml_tensor * llama_model_gemma3n::graph::gaussian_topk(ggml_tensor * x) { ggml_tensor * mean = ggml_mean(ctx0, x); ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), 1.0f / (float) (x->ne[0] - 1))); @@ -318,7 +399,7 @@ ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { // equivalent to compute_router_modalities() in python code // input x shape: [n_embd, n_tokens] // output shape: [n_altup, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_compute_router_modalities(ggml_tensor * x, int il) { ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il); // router_input_scale @@ -330,7 +411,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso // input cur shape: [n_embd, n_tokens, n_altup] // output shape: [n_embd, n_tokens, n_altup] -ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_predict(ggml_tensor * cur, int il) { ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens] ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); @@ -355,7 +436,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { // input predictions shape: [n_embd, n_tokens, n_altup] // input activated shape: [n_embd, n_tokens] // output shape: [n_embd, n_tokens, n_altup] -ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); diff --git a/examples/talk-llama/models/gemma4-iswa.cpp b/examples/talk-llama/models/gemma4.cpp similarity index 62% rename from examples/talk-llama/models/gemma4-iswa.cpp rename to examples/talk-llama/models/gemma4.cpp index c7fb7747414..f45ae4cad59 100644 --- a/examples/talk-llama/models/gemma4-iswa.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -1,5 +1,140 @@ #include "models.h" +void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_26B_A4B; break; + case 35: type = LLM_TYPE_E2B; break; + case 42: type = LLM_TYPE_E4B; break; + case 60: type = LLM_TYPE_31B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; + const int64_t n_ff_exp = hparams.n_ff_exp; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + if (n_embd_per_layer > 0) { + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); + } + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_embd_k = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v = hparams.n_embd_v_gqa(i); + const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); + + if (!hparams.is_swa(i)) { + // full_attention layers use rope_freqs for proportional rope + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + // handle use_double_wide_mlp + int64_t n_ff_cur = hparams.n_ff(i); + + // for expert layers, we use normal FFN as shared expert (same as python code) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // MoE router + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + bool has_expert = layer.ffn_gate_inp != nullptr; + + // norm + if (has_expert) { + layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); + + layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); + + // MoE FFN + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + } + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + + // per-expert scale will be loaded as down_exps_s at the end of the current switch case + } + + // per-layer embeddings + if (n_embd_per_layer > 0) { + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + } + } +} + +std::unique_ptr llama_model_gemma4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { GGML_ASSERT(idx < (int) x->ne[2]); @@ -7,7 +142,7 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); } -llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) : +llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), n_embd_per_layer(model.hparams.n_embd_per_layer) { @@ -157,8 +292,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll cur_moe = build_moe_ffn(cur_moe, nullptr, // gate_inp - nullptr, // up_exps - nullptr, // gate_exps + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, // exp_probs_b (not used for gemma4) n_expert, n_expert_used, @@ -167,8 +302,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, logits, model.layers[il].ffn_gate_up_exps, - nullptr, // up_exps_s - nullptr, // gate_exps_s + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, model.layers[il].ffn_down_exps_s); cur_moe = build_norm(cur_moe, model.layers[il].ffn_post_norm_2, nullptr, @@ -261,7 +396,7 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_per_layer, n_layer, n_tokens] -ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() { +ggml_tensor * llama_model_gemma4::graph::build_inp_per_layer() { auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; @@ -299,7 +434,7 @@ ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() { // inp_batch shape: [n_embd, n_tokens] // inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer) // output shape: [n_embd_per_layer, n_tokens, n_layer] -ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { +ggml_tensor * llama_model_gemma4::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); diff --git a/examples/talk-llama/models/glm-dsa.cpp b/examples/talk-llama/models/glm-dsa.cpp new file mode 100644 index 00000000000..af2b55ef563 --- /dev/null +++ b/examples/talk-llama/models/glm-dsa.cpp @@ -0,0 +1,155 @@ +#include "models.h" + +void llama_model_glm_dsa::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_glm_dsa::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 8d4f4a01553..45886b51ac1 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -1,6 +1,139 @@ #include "models.h" -llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_glm4_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open + case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Load ALL tensors including NextN layer to satisfy total tensor count + // but only PROCESS up to last layer (skipping final NextN layer) in forward pass + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); + + // GLM-style attention with bias terms + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); + layer.ffn_up_shexp = create_tensor( + tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_glm4_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index f0bfda393fa..d6ef76e26d6 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -1,6 +1,78 @@ #include "models.h" -llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_glm4::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // NextN/MTP parameters (GLM-OCR) + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 17: type = LLM_TYPE_1B; break; // GLM-OCR + case 40: type = LLM_TYPE_9B; break; + case 61: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_glm4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index f8dc53eb723..ba49c31b56b 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -1,6 +1,60 @@ #include "models.h" -llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gpt2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gpt2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_gpt2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 0016ddede43..33ebe2d8800 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -1,6 +1,89 @@ #include "models.h" -llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gptneox::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + switch (hparams.n_layer) { + case 6: + switch (hparams.n_ff()) { + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff()) { + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff()) { + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff()) { + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff()) { + case 24576: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gptneox::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_gptneox::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index e983742bef5..12e4790ae24 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -1,6 +1,137 @@ #include "models.h" -llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : +void llama_model_granite_hybrid::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false); + + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + // A layer is recurrent IFF the n_head_kv value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_350M; break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 2048: case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite_hybrid::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + // feed forward (w/ optional biases) + if (n_expert > 0) { + // MoE FFN + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } else { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_granite_hybrid::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -67,7 +198,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, co ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * cur, +ggml_tensor * llama_model_granite_hybrid::graph::build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, const llama_model & model, @@ -98,7 +229,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_granite_hybrid::build_layer_ffn(ggml_tensor * cur, +ggml_tensor * llama_model_granite_hybrid::graph::build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il) { diff --git a/examples/talk-llama/models/granite-moe.cpp b/examples/talk-llama/models/granite-moe.cpp new file mode 100644 index 00000000000..0d89bc1f340 --- /dev/null +++ b/examples/talk-llama/models/granite-moe.cpp @@ -0,0 +1,89 @@ +#include "models.h" + +void llama_model_granite_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_granite_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 6ea90285225..5e7c7b68181 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -1,6 +1,93 @@ #include "models.h" -llm_build_granite::llm_build_granite( +void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_granite::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_granite::graph::graph( const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -68,7 +155,7 @@ llm_build_granite::llm_build_granite( ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_granite::build_attention_layer( +ggml_tensor * llama_model_granite::graph::build_attention_layer( ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -107,7 +194,7 @@ ggml_tensor * llm_build_granite::build_attention_layer( return cur; } -ggml_tensor * llm_build_granite::build_layer_ffn( +ggml_tensor * llama_model_granite::graph::build_layer_ffn( ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index b8f35afdc03..0bc49d00206 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -1,6 +1,89 @@ #include "models.h" -llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_grok::load_arch_hparams(llama_model_loader & ml) { + // defaults for old GGUFs + hparams.yarn_beta_fast = 8.0f; + hparams.f_logit_scale = 0.5773502691896257f; + hparams.f_embedding_scale = 78.38367176906169f; + hparams.f_attn_out_scale = 0.08838834764831845f; + hparams.f_attn_logit_softcapping = 30.0f; + hparams.f_router_logit_softcapping = 30.0f; + // no final_logit_softcapping in grok-1 + hparams.f_final_logit_softcapping = 0.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_grok::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + if (!layer.ffn_post_norm) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } +} + +std::unique_ptr llama_model_grok::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index 151108a2a71..feef815165b 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -1,6 +1,70 @@ #include "models.h" -llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) : +void llama_model_grovemoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp, false); + ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_grovemoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE"); + GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k; + const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0); + layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + } +} + +std::unique_ptr llama_model_grovemoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp index 1cd85d6d9d4..c137bd37c02 100644 --- a/examples/talk-llama/models/hunyuan-dense.cpp +++ b/examples/talk-llama/models/hunyuan-dense.cpp @@ -1,132 +1,6 @@ #include "models.h" -llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v(); - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - GGML_ASSERT(n_embd_head == n_rot); - - const bool use_mrope = hparams.use_mrope(); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv(); - - const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - // self-attention - { - // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - // compute Q and K and RoPE them - auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, - n_embd_head, n_head, n_head_kv, il); - - if (use_mrope) { - Qcur = ggml_rope_multi( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_multi( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } else { - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Kcur = build_norm(Kcur, - model.layers[il].attn_k_norm, nullptr, - LLM_NORM_RMS, il); - cb(Kcur, "Kcur_norm", il); - - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, nullptr, - LLM_NORM_RMS, il); - cb(Qcur, "Qcur_norm", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - // feed-forward network (non-MoE) - ggml_tensor * cur_mlp = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur_mlp, "ffn_out", il); - - cur = ggml_add(ctx0, cur_mlp, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); +std::unique_ptr llama_model_hunyuan_dense::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); } + diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index ffe1664b0e1..44af42412f7 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -1,6 +1,59 @@ #include "models.h" -llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_hunyuan_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_A13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_hunyuan_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + } +} + +std::unique_ptr llama_model_hunyuan_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp new file mode 100644 index 00000000000..5fb9154bec0 --- /dev/null +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -0,0 +1,189 @@ +#include "models.h" + +void llama_model_hunyuan_vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } + + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_0_5B; break; + case 2048: type = LLM_TYPE_1_8B; break; + case 3072: type = LLM_TYPE_4B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_hunyuan_vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + } +} + +std::unique_ptr llama_model_hunyuan_vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_out", il); + + cur = ggml_add(ctx0, cur_mlp, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index 83be2ca0aee..f0c5580a6f4 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -1,6 +1,43 @@ #include "models.h" -llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_internlm2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_internlm2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_internlm2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 31101f3c14b..a6451dca095 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -1,6 +1,58 @@ #include "models.h" -llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_jais::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; + /* TODO: add variants */ + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jais::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_jais::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 507e04fa4aa..ad59b953e8d 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -1,8 +1,63 @@ #include "models.h" +void llama_model_jais2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jais2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_jais2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // JAIS-2 model graph builder // Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings -llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index f82b7795c87..e1b8d137e38 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -1,6 +1,111 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { +void llama_model_jamba::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jamba::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // Attention layers + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } else { + // FFN (no MoE) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_jamba::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; diff --git a/examples/talk-llama/models/jina-bert-v2.cpp b/examples/talk-llama/models/jina-bert-v2.cpp new file mode 100644 index 00000000000..4f8866ece4d --- /dev/null +++ b/examples/talk-llama/models/jina-bert-v2.cpp @@ -0,0 +1,66 @@ +#include "models.h" + +void llama_model_jina_bert_v2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer) { + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jina_bert_v2::load_arch_tensors(llama_model_loader & ml) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; // JinaBertLayer + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); + ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); + const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; + + GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); + layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_jina_bert_v2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/jina-bert-v3.cpp b/examples/talk-llama/models/jina-bert-v3.cpp new file mode 100644 index 00000000000..e0527529f56 --- /dev/null +++ b/examples/talk-llama/models/jina-bert-v3.cpp @@ -0,0 +1,69 @@ +#include "models.h" + +void llama_model_jina_bert_v3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: + type = LLM_TYPE_558M; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jina_bert_v3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_jina_bert_v3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 58c89c417fc..ecffb105496 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -1,7 +1,175 @@ #include "models.h" - #include "llama-memory-recurrent.h" +void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + if (hparams.is_recurrent(i)) { + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), + {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } +} + +std::unique_ptr llama_model_kimi_linear::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // Causal Conv1d function for Q,K,V // When qkv is 0, it is Q, 1 is K, 2 is V static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) { @@ -63,7 +231,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs); } -llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : +llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index eb8ec3c803a..df6a8028736 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -1,10 +1,94 @@ #include "models.h" - #include "../llama-memory-hybrid-iswa.h" #include "../llama-memory-hybrid.h" +void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + hparams.n_layer_dense_lead = hparams.n_layer; + switch (hparams.n_ff()) { + case 4608: type = LLM_TYPE_350M; break; + case 6912: type = LLM_TYPE_700M; break; + case 8192: type = LLM_TYPE_1_2B; break; + case 10752: type = LLM_TYPE_2_6B; break; + default: type = LLM_TYPE_UNKNOWN; + } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + } + } +} + +void llama_model_lfm2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recurrent(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); + + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); +} + +std::unique_ptr llama_model_lfm2::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : +llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { using inp_hybrid_type = std::conditional_t; using inp_attn_type = std::conditional_t; @@ -187,5 +271,5 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ } // Explicit template instantiations -template struct llm_build_lfm2; -template struct llm_build_lfm2; +template struct llama_model_lfm2::graph; +template struct llama_model_lfm2::graph; diff --git a/examples/talk-llama/models/lfm2moe.cpp b/examples/talk-llama/models/lfm2moe.cpp new file mode 100644 index 00000000000..12a66c05c7d --- /dev/null +++ b/examples/talk-llama/models/lfm2moe.cpp @@ -0,0 +1,85 @@ +#include "models.h" +#include "../llama-memory-hybrid-iswa.h" +#include "../llama-memory-hybrid.h" + +void llama_model_lfm2moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_8B_A1B; break; + case 40: type = LLM_TYPE_24B_A2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_lfm2moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recurrent(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); + + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); +} + +std::unique_ptr llama_model_lfm2moe::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index c756d6fde5f..b60f67f6c4b 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -1,6 +1,56 @@ #include "models.h" -llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_llada_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // diffusion language model uses non-causal attention + hparams.causal_attn = false; + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_llada_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_llada_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 501df3c7eaf..fa21c5fe32c 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -1,6 +1,72 @@ #include "models.h" -llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_llada::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion + switch (hparams.n_layer) { + case 32: + type = LLM_TYPE_8B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_llada::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = + create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock + layer.wq = + create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false + layer.wo = + create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, + TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // optional MLP bias + layer.ffn_gate_b = + create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = + create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_llada::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_llada::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // LLaDA is similar to LLaMA but uses non-causal attention for diffusion const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/llama-embed.cpp b/examples/talk-llama/models/llama-embed.cpp new file mode 100644 index 00000000000..0699e744461 --- /dev/null +++ b/examples/talk-llama/models/llama-embed.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr llama_model_llama_embed::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique>(*this, params); +} + diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 8d478dc6747..8ddb5936820 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -1,7 +1,102 @@ #include "models.h" +void llama_model_llama::load_arch_hparams(llama_model_loader & ml) { + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 8) { + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B + case 30: type = LLM_TYPE_256M; break; // smoldocling 256M + // granite uses a vocab with len 49152 + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } +} + +void llama_model_llama::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_llama::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique>(*this, params); +} + template -llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_llama::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -149,5 +244,5 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra ggml_build_forward_expand(gf, cur); } -template struct llm_build_llama; -template struct llm_build_llama; +template struct llama_model_llama::graph; +template struct llama_model_llama::graph; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 4e4bfb43f33..899611d53f6 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -1,7 +1,109 @@ #include "models.h" +void llama_model_llama4::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa == 0) { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope + } else { + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.n_attn_temp_floor_scale = 8192; + hparams.f_attn_temp_scale = 0.1f; + hparams.f_attn_temp_offset = 1.0f; + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } + + switch (hparams.n_expert) { + case 0: { + // MobileLLM (no MoE) + switch (hparams.n_embd) { + case 2048: type = LLM_TYPE_140M; break; + case 4096: type = LLM_TYPE_360M; break; + case 6144: type = LLM_TYPE_950M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case 16: type = LLM_TYPE_17B_16E; break; + case 128: type = LLM_TYPE_17B_128E; break; + default: type = LLM_TYPE_UNKNOWN; + } + + hparams.use_kq_norm = type != LLM_TYPE_17B_128E; +} + +void llama_model_llama4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + const bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + const int64_t n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_llama4::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_llama4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -167,5 +269,5 @@ llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_gr } // Explicit template instantiations -template struct llm_build_llama4; -template struct llm_build_llama4; +template struct llama_model_llama4::graph; +template struct llama_model_llama4::graph; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 8a76931c007..3dbd82fd362 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_maincoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_maincoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_maincoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index 55fd2e055c4..b7708d7fdd1 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -1,6 +1,90 @@ #include "models.h" -llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { +void llama_model_mamba::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mamba::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + if (2 * n_embd != d_inner) { + throw std::runtime_error("only an expansion factor of 2 is supported for now"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } +} + +std::unique_ptr llama_model_mamba::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -51,4 +135,3 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); } - diff --git a/examples/talk-llama/models/mamba2.cpp b/examples/talk-llama/models/mamba2.cpp new file mode 100644 index 00000000000..3277ca53ec4 --- /dev/null +++ b/examples/talk-llama/models/mamba2.cpp @@ -0,0 +1,87 @@ +#include "models.h" + +void llama_model_mamba2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mamba2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } +} + +std::unique_ptr llama_model_mamba2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp deleted file mode 100644 index 52c6acfe214..00000000000 --- a/examples/talk-llama/models/mimo2-iswa.cpp +++ /dev/null @@ -1,129 +0,0 @@ -#include "models.h" - -llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_iswa(); - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - uint32_t n_head_l = hparams.n_head(il); - uint32_t n_head_kv_l = hparams.n_head_kv(il); - const float freq_base_l = model.get_rope_freq_base(cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - cur = inpL; - - // self_attention - { - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - ggml_tensor * sinks = model.layers[il].attn_sinks; - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, model.layers[il].wo_s, - Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // feed-forward network - if (model.layers[il].ffn_gate_inp == nullptr) { - // dense branch - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - // MoE branch - cur = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - hparams.expert_weights_scale, - LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, - il); - cb(cur, "ffn_moe_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp new file mode 100644 index 00000000000..71996616611 --- /dev/null +++ b/examples/talk-llama/models/mimo2.cpp @@ -0,0 +1,240 @@ +#include "models.h" + +void llama_model_mimo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + float value_scale = 0.0f; + if (ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, value_scale, false) && value_scale != 1.0f) { + hparams.f_attn_value_scale = value_scale; + } + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer - hparams.nextn_predict_layers) { + case 48: type = LLM_TYPE_310B_A15B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mimo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const uint32_t n_nextn = hparams.nextn_predict_layers; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + uint32_t n_head = hparams.n_head(i); + + // NextN/MTP layers (the last n_nextn blocks) are preserved but disabled pending support + const bool is_nextn = (n_nextn > 0) && (static_cast(i) >= n_layer - n_nextn); + const int skip = is_nextn ? TENSOR_SKIP : 0; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, skip); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, skip); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, skip); + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED | skip); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, skip); + + // non-MoE branch + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED | skip); + + // MoE branch + int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | skip); + + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, skip); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, skip); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, skip); + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, skip); + } + } +} + +std::unique_ptr llama_model_mimo2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float v_scale = hparams.f_attn_value_scale; + + // The last hparams.nextn_predict_layers blocks are MTP heads, currently inactive + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + uint32_t n_head_l = hparams.n_head(il); + uint32_t n_head_kv_l = hparams.n_head_kv(il); + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // self_attention + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + if (model.layers[il].wqkv) { + // Fused qkv_proj - Q/K share head_dim_k, V uses head_dim_v + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + + const size_t row_k = ggml_row_size(qkv->type, n_embd_head_k); + const size_t row_v = ggml_row_size(qkv->type, n_embd_head_v); + const size_t row_full = qkv->nb[1]; + const size_t k_off = row_k * n_head_l; + const size_t v_off = k_off + row_k * n_head_kv_l; + + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_l, n_tokens, row_k, row_full, 0); + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv_l, n_tokens, row_k, row_full, k_off); + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv_l, n_tokens, row_v, row_full, v_off); + } else { + // Split path + Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * sinks = model.layers[il].attn_sinks; + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); + cb(cur, "attn_out", il); + + if (v_scale) { + cur = ggml_scale(ctx0, cur, v_scale); + cb(cur, "attn_out_scaled", il); + } + } + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/minicpm.cpp b/examples/talk-llama/models/minicpm.cpp new file mode 100644 index 00000000000..966d3af615c --- /dev/null +++ b/examples/talk-llama/models/minicpm.cpp @@ -0,0 +1,89 @@ +#include "models.h" + +void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { + // Backward-compatible defaults for older MiniCPM GGUFs + hparams.f_embedding_scale = 12.0f; + hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); + hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Optional KV reads, override defaults if present in newer GGUF exports + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false); + + // MiniCPM uses rope by default, unlike Granite which uses it as a switch + hparams.rope_finetuned = true; + + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minicpm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_minicpm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index bf12ab73c74..ff5eb6ffa5f 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -1,6 +1,66 @@ #include "models.h" -llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_minicpm3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minicpm3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr llama_model_minicpm3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //TODO: if the model varies, these parameters need to be read from the model const int64_t n_embd_base = 256; const float scale_embd = 12.0f; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index b809b79f2b9..0dee8934692 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_minimax_m2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_230B_A10B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minimax_m2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } +} + +std::unique_ptr llama_model_minimax_m2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index b5ae72a2ee1..708da49af1f 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -1,6 +1,96 @@ #include "models.h" -llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_mistral3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + hparams.f_attn_temp_offset = 0.0f; + + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? + if (hparams.f_attn_temp_scale != 0.0f) { + hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; + if (hparams.n_attn_temp_floor_scale == 0) { + throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); + } + } + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + case 34: type = LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mistral3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_mistral3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mistral4.cpp b/examples/talk-llama/models/mistral4.cpp new file mode 100644 index 00000000000..3d9190650e3 --- /dev/null +++ b/examples/talk-llama/models/mistral4.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr llama_model_mistral4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 94991c55fe8..6d5f18a8e20 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -2,6 +2,7 @@ #include "llama-model.h" #include "llama-graph.h" +#include "llama-model-loader.h" // note: almost all graphs require at least sqrtf, so include cmath globally #include @@ -110,611 +111,1750 @@ struct llm_build_rwkv7_base : public llm_graph_context { // models // -struct llm_build_afmoe : public llm_graph_context { - llm_build_afmoe(const llama_model & model, const llm_graph_params & params); +struct llama_model_llama : public llama_model_base { + llama_model_llama(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_apertus : public llm_graph_context { - llm_build_apertus(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llama4 : public llama_model_base { + llama_model_llama4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_arcee : public llm_graph_context { - llm_build_arcee(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llama_embed : public llama_model_llama { + llama_model_llama_embed(const struct llama_model_params & params) : llama_model_llama(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_llama + + template + using graph = llama_model_llama::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_arctic : public llm_graph_context { - llm_build_arctic(const llama_model & model, const llm_graph_params & params); + +struct llama_model_maincoder : public llama_model_base { + llama_model_maincoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_arwkv7 : public llm_build_rwkv7_base { - llm_build_arwkv7(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deci : public llama_model_base { + llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_baichuan : public llm_graph_context { - llm_build_baichuan(const llama_model & model, const llm_graph_params & params); + +struct llama_model_baichuan : public llama_model_base { + llama_model_baichuan(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bailingmoe2 : public llm_graph_context { - llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_falcon : public llama_model_base { + llama_model_falcon(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bailingmoe : public llm_graph_context { - llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_grok : public llama_model_base { + llama_model_grok(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bert : public llm_graph_context { - llm_build_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_starcoder : public llama_model_base { + llama_model_starcoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bitnet : public llm_graph_context { - llm_build_bitnet(const llama_model & model, const llm_graph_params & params); + +struct llama_model_refact : public llama_model_base { + llama_model_refact(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bloom : public llm_graph_context { - llm_build_bloom(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bert : public llama_model_base { + llama_model_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_chameleon : public llm_graph_context { - llm_build_chameleon(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jina_bert_v2 : public llama_model_base { + llama_model_jina_bert_v2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_chatglm : public llm_graph_context { - llm_build_chatglm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jina_bert_v3 : public llama_model_base { + llama_model_jina_bert_v3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_codeshell : public llm_graph_context { - llm_build_codeshell(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nomic_bert : public llama_model_base { + llama_model_nomic_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_cogvlm : public llm_graph_context { - llm_build_cogvlm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nomic_bert_moe : public llama_model_base { + llama_model_nomic_bert_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_cohere2_iswa : public llm_graph_context { - llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_modern_bert : public llama_model_base { + llama_model_modern_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_command_r : public llm_graph_context { - llm_build_command_r(const llama_model & model, const llm_graph_params & params); + +struct llama_model_neo_bert : public llama_model_base { + llama_model_neo_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_dbrx : public llm_graph_context { - llm_build_dbrx(const llama_model & model, const llm_graph_params & params); + +struct llama_model_eurobert : public llama_model_base { + llama_model_eurobert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_deci : public llm_graph_context { - llm_build_deci(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bloom : public llama_model_base { + llama_model_bloom(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_deepseek2 : public llm_graph_context { - llm_build_deepseek2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mpt : public llama_model_base { + llama_model_mpt(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_deepseek : public llm_graph_context { - llm_build_deepseek(const llama_model & model, const llm_graph_params & params); + +struct llama_model_stablelm : public llama_model_base { + llama_model_stablelm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_dots1 : public llm_graph_context { - llm_build_dots1(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen : public llama_model_base { + llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_dream : public llm_graph_context { - llm_build_dream(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2 : public llama_model_base { + llama_model_qwen2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_ernie4_5 : public llm_graph_context { - llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); + +struct llama_model_dream : public llama_model_base { + llama_model_dream(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_ernie4_5_moe : public llm_graph_context { - llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llada : public llama_model_base { + llama_model_llada(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_paddleocr : public llm_graph_context { - llm_build_paddleocr(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llada_moe : public llama_model_base { + llama_model_llada_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_exaone4 : public llm_graph_context { - llm_build_exaone4(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rnd1 : public llama_model_base { + llama_model_rnd1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_exaone : public llm_graph_context { - llm_build_exaone(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2vl : public llama_model_base { + llama_model_qwen2vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_exaone_moe : public llm_graph_context { - llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2moe : public llama_model_base { + llama_model_qwen2moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_falcon : public llm_graph_context { - llm_build_falcon(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3 : public llama_model_base { + llama_model_qwen3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_falcon_h1 : public llm_build_mamba_base { - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3moe : public llama_model_base { + llama_model_qwen3moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma2_iswa : public llm_graph_context { - llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3vl : public llama_model_base { + llama_model_qwen3vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_gemma3 : public llm_graph_context { - llm_build_gemma3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3vlmoe : public llama_model_base { + llama_model_qwen3vlmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma3n_iswa : public llm_graph_context { - const llama_model & model; - const int64_t n_embd_head; - const int64_t n_embd_altup; - const int64_t n_altup; - const int i_altup_act; - const int n_layer_sparsity = 10; // number of layers using activation sparsity - const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) +struct llama_model_phi2 : public llama_model_base { + llama_model_phi2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params); - ggml_tensor * calc_magnitude(ggml_tensor * x); - // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] - ggml_tensor * build_inp_per_layer(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); +struct llama_model_phi3 : public llama_model_base { + llama_model_phi3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * gaussian_topk(ggml_tensor * x); - ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); - ggml_tensor * altup_predict(ggml_tensor * cur, int il); - ggml_tensor * laurel(ggml_tensor * cur, int il); - ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma4_iswa : public llm_graph_context { - const llama_model & model; - const int64_t n_embd_per_layer; +struct llama_model_phimoe : public llama_model_base { + llama_model_phimoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params); + template + using graph = llama_model_phi3::graph; - // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] - ggml_tensor * build_inp_per_layer(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma_embedding : public llm_graph_context { - llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo : public llama_model_base { + llama_model_plamo(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma : public llm_graph_context { - llm_build_gemma(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo2 : public llama_model_base { + llama_model_plamo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); + ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, + const llama_model & model, int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_glm4 : public llm_graph_context { - llm_build_glm4(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo3 : public llama_model_base { + llama_model_plamo3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_glm4_moe : public llm_graph_context { - llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gpt2 : public llama_model_base { + llama_model_gpt2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gpt2 : public llm_graph_context { - llm_build_gpt2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_codeshell : public llama_model_base { + llama_model_codeshell(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gptneox : public llm_graph_context { - llm_build_gptneox(const llama_model & model, const llm_graph_params & params); + +struct llama_model_orion : public llama_model_base { + llama_model_orion(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_granite : public llm_graph_context { - llm_build_granite(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_attention_layer( - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv * inp_attn, - const llama_model & model, - const int64_t n_embd_head, - const int il); +struct llama_model_internlm2 : public llama_model_base { + llama_model_internlm2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_granite_hybrid : public llm_build_mamba_base { - llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); - ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, - const llama_model & model,const int64_t n_embd_head, const int il); + +struct llama_model_minicpm3 : public llama_model_base { + llama_model_minicpm3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_grok : public llm_graph_context { - llm_build_grok(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma : public llama_model_base { + llama_model_gemma(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_grovemoe : public llm_graph_context { - llm_build_grovemoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma2 : public llama_model_base { + llama_model_gemma2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_hunyuan_dense : public llm_graph_context { - llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma3 : public llama_model_base { + llama_model_gemma3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_hunyuan_moe : public llm_graph_context { - llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma3n : public llama_model_base { + llama_model_gemma3n(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_head; + const int64_t n_embd_altup; + const int64_t n_altup; + const int i_altup_act; + const int n_layer_sparsity = 10; // number of layers using activation sparsity + const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) + + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * calc_magnitude(ggml_tensor * x); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + + ggml_tensor * gaussian_topk(ggml_tensor * x); + ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); + ggml_tensor * altup_predict(ggml_tensor * cur, int il); + ggml_tensor * laurel(ggml_tensor * cur, int il); + ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_internlm2 : public llm_graph_context { - llm_build_internlm2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma4 : public llama_model_base { + llama_model_gemma4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_per_layer; + + graph(const llama_model & model, const llm_graph_params & params); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jais : public llm_graph_context { - llm_build_jais(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma_embedding : public llama_model_base { + llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jais2 : public llm_graph_context { - llm_build_jais2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_starcoder2 : public llama_model_base { + llama_model_starcoder2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jamba : public llm_build_mamba_base { - llm_build_jamba(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mamba : public llama_model_base { + llama_model_mamba(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_kimi_linear : public llm_build_delta_net_base { - llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); - std::pair build_kda_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - int il); +struct llama_model_mamba2 : public llama_model_base { + llama_model_mamba2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - std::pair build_kda_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); + using graph = llama_model_mamba::graph; - const llama_model & model; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_lfm2 : public llm_graph_context { - llm_build_lfm2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jamba : public llama_model_base { + llama_model_jamba(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_llada : public llm_graph_context { - llm_build_llada(const llama_model & model, const llm_graph_params & params); + +struct llama_model_xverse : public llama_model_base { + llama_model_xverse(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_llada_moe : public llm_graph_context { - llm_build_llada_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_command_r : public llama_model_base { + llama_model_command_r(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_cohere2 : public llama_model_base { + llama_model_cohere2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_dbrx : public llama_model_base { + llama_model_dbrx(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_olmo : public llama_model_base { + llama_model_olmo(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_olmo2 : public llama_model_base { + llama_model_olmo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_olmoe : public llama_model_base { + llama_model_olmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_openelm : public llama_model_base { + llama_model_openelm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_gptneox : public llama_model_base { + llama_model_gptneox(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_arctic : public llama_model_base { + llama_model_arctic(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_llama : public llm_graph_context { - llm_build_llama(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek : public llama_model_base { + llama_model_deepseek(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_llama4 : public llm_graph_context { - llm_build_llama4(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek2 : public llama_model_base { + llama_model_deepseek2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_maincoder : public llm_graph_context { - llm_build_maincoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek2ocr : public llama_model_base { + llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mamba : public llm_build_mamba_base { - llm_build_mamba(const llama_model & model, const llm_graph_params & params); + +struct llama_model_glm_dsa : public llama_model_base { + llama_model_glm_dsa(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mimo2_iswa : public llm_graph_context { - llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mistral4 : public llama_model_deepseek2 { + llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_deepseek2 + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_minicpm3 : public llm_graph_context { - llm_build_minicpm3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_chatglm : public llama_model_base { + llama_model_chatglm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_minimax_m2 : public llm_graph_context { - llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_glm4 : public llama_model_base { + llama_model_glm4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mistral3 : public llm_graph_context { - llm_build_mistral3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_glm4_moe : public llama_model_base { + llama_model_glm4_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_modern_bert : public llm_graph_context { - llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bitnet : public llama_model_base { + llama_model_bitnet(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mpt : public llm_graph_context { - llm_build_mpt(const llama_model & model, const llm_graph_params & params); + +struct llama_model_t5 : public llama_model_base { + llama_model_t5(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_nemotron : public llm_graph_context { - llm_build_nemotron(const llama_model & model, const llm_graph_params & params); + +struct llama_model_t5encoder : public llama_model_base { + llama_model_t5encoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_t5::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_nemotron_h : public llm_build_mamba_base { - llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); - ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, int64_t n_embd_head, int il); + +struct llama_model_jais : public llama_model_base { + llama_model_jais(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_neo_bert : public llm_graph_context { - llm_build_neo_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jais2 : public llama_model_base { + llama_model_jais2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_eurobert : public llm_graph_context { - llm_build_eurobert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nemotron : public llama_model_base { + llama_model_nemotron(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_olmo2 : public llm_graph_context { - llm_build_olmo2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nemotron_h : public llama_model_base { + llama_model_nemotron_h(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); + ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, + const llama_model & model, int64_t n_embd_head, int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_olmoe : public llm_graph_context { - llm_build_olmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nemotron_h_moe : public llama_model_nemotron_h { + llama_model_nemotron_h_moe(const struct llama_model_params & params) : llama_model_nemotron_h(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_nemotron_h + + using graph = llama_model_nemotron_h::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_olmo : public llm_graph_context { - llm_build_olmo(const llama_model & model, const llm_graph_params & params); + +struct llama_model_exaone : public llama_model_base { + llama_model_exaone(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_openai_moe_iswa : public llm_graph_context { - llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_exaone4 : public llama_model_base { + llama_model_exaone4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_openelm : public llm_graph_context { - llm_build_openelm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_exaone_moe : public llama_model_base { + llama_model_exaone_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_orion : public llm_graph_context { - llm_build_orion(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rwkv6 : public llama_model_base { + llama_model_rwkv6(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv6_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_pangu_embedded : public llm_graph_context { - llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rwkv6qwen2 : public llama_model_base { + llama_model_rwkv6qwen2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv6_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_phi2 : public llm_graph_context { - llm_build_phi2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rwkv7 : public llama_model_base { + llama_model_rwkv7(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv7_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_phi3 : public llm_graph_context { - llm_build_phi3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_arwkv7 : public llama_model_base { + llama_model_arwkv7(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv7_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plamo2 : public llm_build_mamba_base { - llm_build_plamo2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_granite : public llama_model_base { + llama_model_granite(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + private: - ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); - ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, - const llama_model & model, int il); + ggml_tensor * build_attention_layer( + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plamo : public llm_graph_context { - llm_build_plamo(const llama_model & model, const llm_graph_params & params); + +struct llama_model_granite_moe : public llama_model_base { + llama_model_granite_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_granite::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_plamo3 : public llm_graph_context { - llm_build_plamo3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_minicpm : public llama_model_base { + llama_model_minicpm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_granite::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plm : public llm_graph_context { - llm_build_plm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_granite_hybrid : public llama_model_base { + llama_model_granite_hybrid(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); + ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, + const llama_model & model,const int64_t n_embd_head, const int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2 : public llm_graph_context { - llm_build_qwen2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_chameleon : public llama_model_base { + llama_model_chameleon(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2moe : public llm_graph_context { - llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_wavtokenizer_dec : public llama_model_base { + llama_model_wavtokenizer_dec(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2vl : public llm_graph_context { - llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plm : public llama_model_base { + llama_model_plm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3 : public llm_graph_context { - llm_build_qwen3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bailingmoe : public llama_model_base { + llama_model_bailingmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3moe : public llm_graph_context { - llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bailingmoe2 : public llama_model_base { + llama_model_bailingmoe2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3vl : public llm_graph_context { - llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params); + +struct llama_model_seed_oss : public llama_model_base { + llama_model_seed_oss(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3vlmoe : public llm_graph_context { - llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_dots1 : public llama_model_base { + llama_model_dots1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3next : public llm_build_delta_net_base { - llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_layer_attn( - llm_graph_input_attn_kv * inp_attn, - ggml_tensor * cur, - ggml_tensor * inp_pos, - int il); - ggml_tensor * build_layer_attn_linear( - llm_graph_input_rs * inp, - ggml_tensor * cur, - int il); +struct llama_model_arcee : public llama_model_base { + llama_model_arcee(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - ggml_tensor * build_norm_gated( - ggml_tensor * input, - ggml_tensor * weights, - ggml_tensor * gate, - int layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - // returns pair of qkv, z - std::pair build_qkvz( - ggml_tensor * input, - int il); - const llama_model & model; +struct llama_model_afmoe : public llama_model_base { + llama_model_afmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen35 : public llm_build_delta_net_base { - llm_build_qwen35(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_layer_attn( - llm_graph_input_attn_kv * inp_attn, - ggml_tensor * cur, - ggml_tensor * inp_pos, - int * sections, - int il); - ggml_tensor * build_layer_attn_linear( - llm_graph_input_rs * inp, - ggml_tensor * cur, - int il); +struct llama_model_ernie4_5 : public llama_model_base { + llama_model_ernie4_5(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - ggml_tensor * build_norm_gated( - ggml_tensor * input, - ggml_tensor * weights, - ggml_tensor * gate, - int layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - // returns pair of qkv, z - std::pair build_qkvz( - ggml_tensor * input, - int il); - const llama_model & model; +struct llama_model_ernie4_5_moe : public llama_model_ernie4_5 { + llama_model_ernie4_5_moe(const struct llama_model_params & params) : llama_model_ernie4_5(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_ernie4_5 + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -// TODO: derive llm_build_delta_net_base instead -struct llm_build_qwen35moe : public llm_build_delta_net_base { - llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_layer_attn( - llm_graph_input_attn_kv * inp_attn, - ggml_tensor * cur, - ggml_tensor * inp_pos, - int * sections, - int il); - ggml_tensor * build_layer_attn_linear( - llm_graph_input_rs * inp, - ggml_tensor * cur, - int il); +struct llama_model_paddleocr : public llama_model_ernie4_5 { + llama_model_paddleocr(const struct llama_model_params & params) : llama_model_ernie4_5(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_ernie4_5 - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - ggml_tensor * build_norm_gated( - ggml_tensor * input, - ggml_tensor * weights, - ggml_tensor * gate, - int layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - // returns pair of qkv, z - std::pair build_qkvz( - ggml_tensor * input, - int il); - const llama_model & model; +struct llama_model_hunyuan_moe : public llama_model_base { + llama_model_hunyuan_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen : public llm_graph_context { - llm_build_qwen(const llama_model & model, const llm_graph_params & params); + +struct llama_model_hunyuan_vl : public llama_model_base { + llama_model_hunyuan_vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_refact : public llm_graph_context { - llm_build_refact(const llama_model & model, const llm_graph_params & params); + +struct llama_model_hunyuan_dense : public llama_model_hunyuan_vl { + llama_model_hunyuan_dense(const struct llama_model_params & params) : llama_model_hunyuan_vl(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_hunyuan_vl + + using graph = llama_model_hunyuan_vl::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rnd1 : public llm_graph_context { - llm_build_rnd1(const llama_model & model, const llm_graph_params & params); + +struct llama_model_smollm3 : public llama_model_base { + llama_model_smollm3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv6 : public llm_build_rwkv6_base { - llm_build_rwkv6(const llama_model & model, const llm_graph_params & params); + +struct llama_model_openai_moe : public llama_model_base { + llama_model_openai_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { - llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_falcon_h1 : public llama_model_base { + llama_model_falcon_h1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv7 : public llm_build_rwkv7_base { - llm_build_rwkv7(const llama_model & model, const llm_graph_params & params); + +struct llama_model_lfm2 : public llama_model_base { + llama_model_lfm2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_seed_oss : public llm_graph_context { - llm_build_seed_oss(const llama_model & model, const llm_graph_params & params); + +struct llama_model_lfm2moe : public llama_model_base { + llama_model_lfm2moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + using graph = llama_model_lfm2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_smallthinker : public llm_graph_context { - llm_build_smallthinker(const llama_model & model, const llm_graph_params & params); + +struct llama_model_smallthinker : public llama_model_base { + llama_model_smallthinker(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_smollm3 : public llm_graph_context { - llm_build_smollm3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_grovemoe : public llama_model_base { + llama_model_grovemoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_stablelm : public llm_graph_context { - llm_build_stablelm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_apertus : public llama_model_base { + llama_model_apertus(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_starcoder2 : public llm_graph_context { - llm_build_starcoder2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_minimax_m2 : public llama_model_base { + llama_model_minimax_m2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_starcoder : public llm_graph_context { - llm_build_starcoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_cogvlm : public llama_model_base { + llama_model_cogvlm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_step35_iswa : public llm_graph_context { - llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_pangu_embed : public llama_model_base { + llama_model_pangu_embed(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_t5 : public llm_graph_context { - llm_build_t5(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3next : public llama_model_base { + llama_model_qwen3next(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_t5encoder : public llm_build_t5 { - llm_build_t5encoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen35 : public llama_model_base { + llama_model_qwen35(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_wavtokenizer_dec : public llm_graph_context { - llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen35moe : public llama_model_base { + llama_model_qwen35moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_mistral3 : public llama_model_base { + llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_mimo2 : public llama_model_base { + llama_model_mimo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_xverse : public llm_graph_context { - llm_build_xverse(const llama_model & model, const llm_graph_params & params); + +struct llama_model_kimi_linear : public llama_model_base { + llama_model_kimi_linear(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + + std::pair build_kda_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * gk, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + std::pair build_kda_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * gk, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_step35 : public llama_model_base { + llama_model_step35(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index 5c6a1b5e1bc..e9b79ffc6dc 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -1,6 +1,69 @@ #include "models.h" -llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_modern_bert::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 3; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embedding-small + case 22: + type = LLM_TYPE_149M; break; // modern-bert-base + case 28: + type = LLM_TYPE_395M; break; // modern-bert-large + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_modern_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else{ + // layer 0 uses identity + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); + +} + +std::unique_ptr llama_model_modern_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_modern_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index 8596bbb2024..cfc60e8de29 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -1,6 +1,70 @@ #include "models.h" -llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_mpt::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mpt::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + + // FIXME test-llama-archs crashes if q_norm is created + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + // AWQ ScaleActivation layer + layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_mpt::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/nemotron-h-moe.cpp b/examples/talk-llama/models/nemotron-h-moe.cpp new file mode 100644 index 00000000000..a59cc6c9fbd --- /dev/null +++ b/examples/talk-llama/models/nemotron-h-moe.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr llama_model_nemotron_h_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index dc07d43df58..865461f61db 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -1,6 +1,127 @@ #include "models.h" -llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : +void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // A layer is recurrent IFF the n_head_kv value is set to 0 and + // the n_ff value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); + + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B + case 56: type = LLM_TYPE_9B; break; + case 88: type = LLM_TYPE_120B_A12B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_nemotron_h::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd; + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // all blocks use the attn norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else if (hparams.n_ff(i) == 0) { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } else { + if (n_expert != 0) { + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); + + // MoE branch + layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } + } + } +} + +std::unique_ptr llama_model_nemotron_h::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -60,7 +181,7 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, +ggml_tensor * llama_model_nemotron_h::graph::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, int64_t n_embd_head, @@ -76,7 +197,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { +ggml_tensor * llama_model_nemotron_h::graph::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 054b16fe0ef..0c72ed297aa 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -1,6 +1,52 @@ #include "models.h" -llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_nemotron::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_nemotron::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_nemotron::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index da68024a34d..f00d6eddfc9 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_neo_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_layer == 28) { + type = LLM_TYPE_250M; + } +} + +void llama_model_neo_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_neo_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_neo_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/nomic-bert-moe.cpp b/examples/talk-llama/models/nomic-bert-moe.cpp new file mode 100644 index 00000000000..a17abe2c269 --- /dev/null +++ b/examples/talk-llama/models/nomic-bert-moe.cpp @@ -0,0 +1,72 @@ +#include "models.h" + +void llama_model_nomic_bert_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); + + if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } +} + +void llama_model_nomic_bert_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_nomic_bert_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/nomic-bert.cpp b/examples/talk-llama/models/nomic-bert.cpp new file mode 100644 index 00000000000..5a8a5584457 --- /dev/null +++ b/examples/talk-llama/models/nomic-bert.cpp @@ -0,0 +1,72 @@ +#include "models.h" + +void llama_model_nomic_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); + + if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } +} + +void llama_model_nomic_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_nomic_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index a9974025f07..161035e72bc 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_olmo::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmo::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_olmo::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 308d2a600c2..9633f269965 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -1,7 +1,68 @@ #include "models.h" +void llama_model_olmo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_olmo2::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -146,5 +207,5 @@ llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_grap } // Explicit template instantiations -template struct llm_build_olmo2; -template struct llm_build_olmo2; +template struct llama_model_olmo2::graph; +template struct llama_model_olmo2::graph; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index ed46a00ef90..4bb9013054c 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_olmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr llama_model_olmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe.cpp similarity index 51% rename from examples/talk-llama/models/openai-moe-iswa.cpp rename to examples/talk-llama/models/openai-moe.cpp index 50992b8d506..13a590ce646 100644 --- a/examples/talk-llama/models/openai-moe-iswa.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -1,6 +1,67 @@ #include "models.h" -llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_openai_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_20B; break; + case 36: type = LLM_TYPE_120B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_openai_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); + layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); + layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_openai_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 514ac33517f..b4128e116e7 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -1,6 +1,53 @@ #include "models.h" -llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_openelm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_openelm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; + const int64_t n_ff = hparams.n_ff(i); + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_openelm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index a5874b6dee7..7ace0a5139d 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_orion::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_orion::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_orion::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_orion::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 56cb1d94c5f..1c0eadefa98 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -1,6 +1,10 @@ #include "models.h" -llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) : +std::unique_ptr llama_model_paddleocr::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // NOTE: same with qwen2vl.cpp, but bias tensors are optional diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embed.cpp similarity index 53% rename from examples/talk-llama/models/pangu-embedded.cpp rename to examples/talk-llama/models/pangu-embed.cpp index 53464f21d22..41b7e2ac23e 100644 --- a/examples/talk-llama/models/pangu-embedded.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -1,6 +1,60 @@ #include "models.h" -llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_pangu_embed::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 + case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_pangu_embed::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // weight tensors + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_pangu_embed::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 0fb3ffa2e63..a333602c72d 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_phi2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_phi2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_phi2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 39af285d3c5..0a65e91fefa 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -1,7 +1,71 @@ #include "models.h" +void llama_model_phi3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + + // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + + hparams.n_swa = 0; + hparams.set_swa_pattern(1); + } +} + +void llama_model_phi3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr llama_model_phi3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_phi3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -128,5 +192,5 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ } // Explicit template instantiations -template struct llm_build_phi3; -template struct llm_build_phi3; +template struct llama_model_phi3::graph; +template struct llama_model_phi3::graph; diff --git a/examples/talk-llama/models/phimoe.cpp b/examples/talk-llama/models/phimoe.cpp new file mode 100644 index 00000000000..4575d6139cf --- /dev/null +++ b/examples/talk-llama/models/phimoe.cpp @@ -0,0 +1,55 @@ +#include "models.h" + +void llama_model_phimoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_phimoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr llama_model_phimoe::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 4d5c84506c2..4c16c20a0d4 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -1,6 +1,42 @@ #include "models.h" -llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_plamo::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_plamo::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index b6142daebd9..29c8702606a 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -1,8 +1,109 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : +void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: + if (hparams.n_embd == 2048) { + type = LLM_TYPE_2B; + } else if (hparams.n_embd == 4096) { + type = LLM_TYPE_8B; + } + break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba parameters + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k(); + const uint32_t v_dim = hparams.n_embd_head_v(); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_plamo2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -95,7 +196,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, +ggml_tensor * llama_model_plamo2::graph::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, const llama_model & model, @@ -150,7 +251,7 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv return cur; } -ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp, +ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 67844c09f24..849f1579e63 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -1,7 +1,74 @@ #include "models.h" +void llama_model_plamo3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 8; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t head_dim_q = hparams.n_embd_head_k(); + const int64_t head_dim_v = hparams.n_embd_head_v(); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t q_proj_dim = num_attention_heads * head_dim_q; + const int64_t k_proj_dim = num_key_value_heads * head_dim_q; + const int64_t v_proj_dim = num_key_value_heads * head_dim_v; + const int64_t n_ff_cur = hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), + {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + } +} + +std::unique_ptr llama_model_plamo3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) : +llama_model_plamo3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t head_dim_q = hparams.n_embd_head_k(); const int64_t head_dim_v = hparams.n_embd_head_v(); @@ -126,5 +193,5 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr } // Explicit template instantiations -template struct llm_build_plamo3; -template struct llm_build_plamo3; +template struct llama_model_plamo3::graph; +template struct llama_model_plamo3::graph; diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index abce6b34d04..57f5995103b 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_plm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); + const int64_t kv_lora_rank = hparams.n_lora_kv; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_plm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k())); const uint32_t n_embd_head_qk_rope = hparams.n_rot(); diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 44e75d87437..cdc076cdf77 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); + } +} + +std::unique_ptr llama_model_qwen::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 2892dd75087..6320458a13b 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 5f0a6861b68..7587c802c68 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -1,6 +1,67 @@ #include "models.h" -llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen2moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen2moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } +} + +std::unique_ptr llama_model_qwen2moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index da7937c7667..1a40fa89be4 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -1,6 +1,45 @@ #include "models.h" -llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen2vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); +} +// fall through + +void llama_model_qwen2vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen2vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 883dd5f9a90..fa656c84ea0 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output rerank head + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 87790f08e4e..f276be61ba8 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -1,8 +1,96 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : +void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; + case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; + case 64: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -87,7 +175,7 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -std::pair llm_build_qwen35::build_qkvz( +std::pair llama_model_qwen35::graph::build_qkvz( ggml_tensor * input, int il) { const int64_t n_seqs = ubatch.n_seqs; @@ -103,7 +191,7 @@ std::pair llm_build_qwen35::build_qkvz( return { qkv_mixed, z }; } -ggml_tensor * llm_build_qwen35::build_norm_gated( +ggml_tensor * llama_model_qwen35::graph::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, ggml_tensor * gate, @@ -114,7 +202,7 @@ ggml_tensor * llm_build_qwen35::build_norm_gated( return ggml_mul(ctx0, normalized, gated_silu); } -ggml_tensor * llm_build_qwen35::build_layer_attn( +ggml_tensor * llama_model_qwen35::graph::build_layer_attn( llm_graph_input_attn_kv * inp, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -195,7 +283,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( return cur; } -ggml_tensor * llm_build_qwen35::build_layer_attn_linear( +ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, int il) { @@ -369,7 +457,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( return cur; } -ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) { +ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, const int il) { // Qwen3.5 does not use MoE FFN GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index 7dc6a23c751..cf05dc9d61c 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -1,8 +1,109 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : +void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_122B_A10B; break; + case 60: type = LLM_TYPE_397B_A17B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } +} + +std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -87,7 +188,7 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -std::pair llm_build_qwen35moe::build_qkvz( +std::pair llama_model_qwen35moe::graph::build_qkvz( ggml_tensor * input, int il) { const int64_t n_seqs = ubatch.n_seqs; @@ -103,7 +204,7 @@ std::pair llm_build_qwen35moe::build_qkvz( return { qkv_mixed, z }; } -ggml_tensor * llm_build_qwen35moe::build_norm_gated( +ggml_tensor * llama_model_qwen35moe::graph::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, ggml_tensor * gate, @@ -114,7 +215,7 @@ ggml_tensor * llm_build_qwen35moe::build_norm_gated( return ggml_mul(ctx0, normalized, gated_silu); } -ggml_tensor * llm_build_qwen35moe ::build_layer_attn( +ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn( llm_graph_input_attn_kv * inp, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -195,7 +296,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( return cur; } -ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( +ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, int il) { @@ -369,7 +470,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( return cur; } -ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int il) { +ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, const int il) { // Check if this is an MoE layer GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr); diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 16bedba994d..4440b83aa45 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -1,6 +1,65 @@ #include "models.h" -llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_qwen3moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index 1beda70b7cf..cb1b4814caf 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -1,8 +1,113 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : +void llama_model_qwen3next::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3next::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + // Calculate projection sizes + const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + // note: ssm_in is used by legacy GGUF + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } +} + +std::unique_ptr llama_model_qwen3next::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -87,7 +192,7 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -ggml_tensor * llm_build_qwen3next::build_norm_gated( +ggml_tensor * llama_model_qwen3next::graph::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, ggml_tensor * gate, @@ -98,7 +203,7 @@ ggml_tensor * llm_build_qwen3next::build_norm_gated( return ggml_mul(ctx0, normalized, gated_silu); } -ggml_tensor * llm_build_qwen3next::build_layer_attn( +ggml_tensor * llama_model_qwen3next::graph::build_layer_attn( llm_graph_input_attn_kv * inp, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -178,7 +283,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( return cur; } -std::pair llm_build_qwen3next::build_qkvz( +std::pair llama_model_qwen3next::graph::build_qkvz( ggml_tensor * input, int il) { const int64_t d_inner = hparams.ssm_d_inner; @@ -259,7 +364,7 @@ std::pair llm_build_qwen3next::build_qkvz( } } -ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( +ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, int il) { @@ -468,7 +573,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( return cur; } -ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { +ggml_tensor * llama_model_qwen3next::graph::build_layer_ffn(ggml_tensor * cur, const int il) { // Check if this is an MoE layer if (model.layers[il].ffn_gate_inp != nullptr) { // MoE branch diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index faa5f2ef3c8..7871f8f7952 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -1,6 +1,56 @@ #include "models.h" -llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output rerank head + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen3vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; const int64_t n_embd = hparams.n_embd; diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp similarity index 57% rename from examples/talk-llama/models/qwen3vl-moe.cpp rename to examples/talk-llama/models/qwen3vlmoe.cpp index 29ee8278a4d..b99143c8908 100644 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -1,6 +1,66 @@ #include "models.h" -llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3vlmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3vlmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_qwen3vlmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; const int64_t n_embd = hparams.n_embd; @@ -127,4 +187,3 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ ggml_build_forward_expand(gf, cur); } - diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index 398eb368db0..f14f10917ff 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -1,6 +1,81 @@ #include "models.h" -llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_refact::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; +} + +void llama_model_refact::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_refact::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_refact::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index a917c19f25a..325ee73ba5c 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -1,7 +1,67 @@ #include "models.h" +void llama_model_rnd1::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_rnd1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_rnd1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // RND1 is a Qwen3Moe AR model converted to diffusion model. -llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 032b219d6cb..2944711acec 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -1,6 +1,97 @@ #include "models.h" -llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_params & params) : +void llama_model_rwkv6::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv6::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); + GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); + } + +} + +std::unique_ptr llama_model_rwkv6::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index e84e5973820..6f7d1f5722f 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -1,6 +1,87 @@ #include "models.h" -llm_build_rwkv6qwen2::llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { +void llama_model_rwkv6qwen2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv6qwen2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + int attn_key_value_size; + if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { + attn_key_value_size = attn_hidden_size; + } else { + attn_key_value_size = n_head_kv * head_size; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + // optional bias tensors + layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); + + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_rwkv6qwen2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 16ffa6901b9..b205e3935e1 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -1,6 +1,127 @@ #include "models.h" -llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_params & params) : +void llama_model_rwkv7::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv7::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); + + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + } + +} + +std::unique_ptr llama_model_rwkv7::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 6db8d9781fe..83e114740b6 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -1,6 +1,51 @@ #include "models.h" -llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_seed_oss::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_36B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_seed_oss::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const uint32_t head_dim = hparams.n_embd_head_k(); + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); + + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_seed_oss::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 55d09ec325d..3214e7cbad3 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -1,7 +1,80 @@ #include "models.h" +void llama_model_smallthinker::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; + } + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_4B; break; + case 52: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_smallthinker::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER"); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + } +} + +std::unique_ptr llama_model_smallthinker::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ +llama_model_smallthinker::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -113,5 +186,5 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, } // Explicit template instantiations -template struct llm_build_smallthinker; -template struct llm_build_smallthinker; +template struct llama_model_smallthinker::graph; +template struct llama_model_smallthinker::graph; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 83636dbf546..7adaf34c534 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_smollm3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.n_no_rope_layer_step = 4; + + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_smollm3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_smollm3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 9c19abd8835..8f613e55947 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -1,6 +1,54 @@ #include "models.h" -llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_stablelm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_stablelm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_stablelm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index cf9fe95c35b..58cf0ac0edc 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -1,6 +1,62 @@ #include "models.h" -llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_starcoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_starcoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + // needs to be on GPU + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_starcoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index b6d4d5aac1a..45dae0602d4 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -1,6 +1,61 @@ #include "models.h" -llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_starcoder2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_starcoder2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional bias tensors + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); + } +} + +std::unique_ptr llama_model_starcoder2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35.cpp similarity index 52% rename from examples/talk-llama/models/step35-iswa.cpp rename to examples/talk-llama/models/step35.cpp index 86aa98909e7..c4789752d21 100644 --- a/examples/talk-llama/models/step35-iswa.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -1,6 +1,108 @@ #include "models.h" -llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // full_attention layer only use half of the RoPE dimensions + hparams.n_rot_full = hparams.n_rot_full / 2; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_step35::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_step35::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 9f9dfef4012..27a0711ba41 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -1,7 +1,125 @@ #include "models.h" +void llama_model_t5::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + hparams.dec_n_layer = hparams.n_layer; + ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); + + switch (hparams.n_layer) { + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_t5::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // n_layer: number of encoder_layers + // dec_n_layer: number of decoder_layers + const int dec_n_layer = hparams.dec_n_layer; + if (dec_n_layer > n_layer) { + layers.resize(dec_n_layer); + } + + // load encoder layers + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // load decoder layers + for (int i = 0; i < dec_n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); + // this tensor seems to be unused in HF transformers implementation + layer.attn_rel_b_cross = create_tensor( + tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + + layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_t5::build_arch_graph(const llm_graph_params & params) const { + switch (params.gtype) { + case LLM_GRAPH_TYPE_ENCODER: + return std::make_unique>(*this, params); + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + return std::make_unique>(*this, params); + default: + GGML_ABORT("invalid graph type"); + }; +} + template <> -llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_t5::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -156,7 +274,7 @@ llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_par } template <> -llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_t5::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/t5encoder.cpp b/examples/talk-llama/models/t5encoder.cpp index 5c1f9eb4030..23c5f9b6a1c 100644 --- a/examples/talk-llama/models/t5encoder.cpp +++ b/examples/talk-llama/models/t5encoder.cpp @@ -1,3 +1,44 @@ #include "models.h" -llm_build_t5encoder::llm_build_t5encoder(const llama_model & model, const llm_graph_params & params) : llm_build_t5(model, params) {} +void llama_model_t5encoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_t5encoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_t5encoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index a7776d9cdc9..a873e5d2e8f 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -1,6 +1,121 @@ #include "models.h" -llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_wavtokenizer_dec::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); +} + +void llama_model_wavtokenizer_dec::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); + + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); + + // posnet + { + const int64_t n_embd = hparams.posnet.n_embd; + + for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { + auto & layer = layers[i].posnet; + + // posnet: + // + // - resnet + // - resnet + // - attn + // - resnet + // - resnet + // - norm + // + switch (i) { + case 0: + case 1: + case 3: + case 4: + { + layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); + layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); + + layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); + + layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); + layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); + + layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); + } break; + case 2: + { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + + layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); + + layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); + + layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); + + layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); + } break; + case 5: + { + layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + } break; + default: GGML_ABORT("unknown posnet layer"); + }; + } + } + + GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); + + // convnext + { + const int64_t n_embd = hparams.convnext.n_embd; + + for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { + auto & layer = layers[i].convnext; + + layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); + layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); + + layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); + + layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); + layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); + + layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); + layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); + + layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); +} + +std::unique_ptr llama_model_wavtokenizer_dec::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index 53085ec80f6..e4d111e622a 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -1,6 +1,43 @@ #include "models.h" -llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_xverse::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_xverse::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_xverse::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); From f6f32a7f51c3e1c9fddb5ae55a7848221c667276 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 14:07:30 +0200 Subject: [PATCH 044/289] try to fix window cublas CI failure Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/25631391231/job/75237266964?pr=3803 --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index be3f78a3f5b..df390a9179c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -944,7 +944,7 @@ jobs: cmake --version where cmake if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR + set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR -D__CUDA_NO_HALF_CONVERSIONS__ ) else ( set CUDA_FLAGS= ) From 1665885f769e1cefa429d375f0044406f0665989 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 14:39:16 +0200 Subject: [PATCH 045/289] Revert "try to fix window cublas CI failure" This reverts commit a4d91768aa2ae8cf7083650b3e4dc214413f92b7. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index df390a9179c..be3f78a3f5b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -944,7 +944,7 @@ jobs: cmake --version where cmake if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR -D__CUDA_NO_HALF_CONVERSIONS__ + set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR ) else ( set CUDA_FLAGS= ) From e0bfd3ae4d50efd2959b4ae6407210bf74c921ab Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 14:44:23 +0200 Subject: [PATCH 046/289] try using CCCL 12.4.127 with cuda 11.8.0 to fix CI failure --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index be3f78a3f5b..423b1b28b22 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -822,7 +822,7 @@ jobs: $NVTX_VER = "11.8.86" $VS_VER = "11.8.86" $NVPROF_VER = "11.8.87" - $CCCL_VER = "11.8.89" + $CCCL_VER = "12.4.127" # Create the directory where the CUDA Toolkit will be installed mkdir -p $CUDA_TOOLKIT_DIR From 5b2d4af850edf31dc23e750a769128c4b0feac1a Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 15:17:13 +0200 Subject: [PATCH 047/289] Revert "try using CCCL 12.4.127 with cuda 11.8.0 to fix CI failure" This reverts commit be867eadf553801eb7d1c383ed47a90fdd3d4b18. Sorry about this noise, I thought it was worth a try. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 423b1b28b22..be3f78a3f5b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -822,7 +822,7 @@ jobs: $NVTX_VER = "11.8.86" $VS_VER = "11.8.86" $NVPROF_VER = "11.8.87" - $CCCL_VER = "12.4.127" + $CCCL_VER = "11.8.89" # Create the directory where the CUDA Toolkit will be installed mkdir -p $CUDA_TOOLKIT_DIR From 633de7f99e692fe5de95edb8eb9a778f74de548d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 06:38:12 +0200 Subject: [PATCH 048/289] devops : add spirv-headers to vulkan dockerfile --- .devops/main-vulkan.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile index 2be22e4d53b..077af4f1001 100644 --- a/.devops/main-vulkan.Dockerfile +++ b/.devops/main-vulkan.Dockerfile @@ -2,7 +2,7 @@ FROM ubuntu:24.04 AS build WORKDIR /app RUN apt-get update && \ - apt-get install -y build-essential wget cmake git libvulkan-dev glslc \ + apt-get install -y build-essential wget cmake git libvulkan-dev spirv-headers glslc \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . From b1ebddf154c38adfaff4448ac37b8564d066f559 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 07:59:24 +0200 Subject: [PATCH 049/289] ggml-cuda : add explicit casts to -INFINITY for float and half2 types This commit adds explicit casts to float for -INFINITY. The motivation for this is that in CUDA 11.8.0, the -INFINITY macro is defined as a double (a header provided NVCC). This triggers a warning and hence causes a CI failure in whisper.cpp. I belive that this header might have been updated in CUDA 12 which is why we don't see this warning. Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/25713948217/job/75500081939?pr=3803 Refs: https://github.com/ggml-org/llama.cpp/issues/22824 --- ggml/src/ggml-cuda/common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10817505d9f..246a76193ca 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -582,9 +582,9 @@ template struct block_reduce_policy { static __device__ T sentinel() { if constexpr (std::is_same_v) { - return -INFINITY; + return -(float)INFINITY; } else if constexpr (std::is_same_v) { - return make_half2(-INFINITY, -INFINITY); + return make_half2(__float2half(-(float)INFINITY), __float2half(-(float)INFINITY)); } else { static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); } From b6a4b32a88b743bc42d5f849e435384b14ddeab8 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 08:30:00 +0200 Subject: [PATCH 050/289] ggml-cuda : add ar_add() to avoid ambiguous operator+ for half/bfloat16 in CUDA 11.8 --- ggml/src/ggml-cuda/allreduce.cu | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 434689abd95..03d88968cd5 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -105,6 +105,20 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; // blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are // handled only by block 0 to avoid cross-block writes to the same slots. // --------------------------------------------------------------------------- + +// CUDA 11.8 does not expose operator+ for half/bfloat16 below sm_530, +// so use the explicit intrinsics to avoid ambiguous implicit conversions. +template +static __device__ inline T ar_add(T a, T b) { + if constexpr (std::is_same_v) { + return __hadd(a, b); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); + } else { + return a + b; + } +} + template static __global__ void ggml_cuda_ar_kernel( const T_dst * sendbuf, @@ -184,13 +198,13 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); + recvbuf[off + k] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(wire[k])); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); recvbuf[tail + tid] = - ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); + ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(host_other[tail + tid])); } } } @@ -210,7 +224,7 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); + dst[i] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(src[i])); } } From d04a1faaec814772ca29f801ad7a10f4c330e16f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 08:36:14 +0200 Subject: [PATCH 051/289] ci : update ONEAPI version to 2025.3.3-0-devel-ubuntu24.04 --- .devops/main-intel.Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index 1b5859715d4..dbb60682dce 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -1,6 +1,6 @@ -ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04 +ARG ONEAPI_VERSION=2025.3.3-0-devel-ubuntu24.04 -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build WORKDIR /app RUN apt-get update && \ @@ -16,7 +16,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ fi && \ make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}" -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build WORKDIR /app RUN apt-get update && \ From ea29be532eac424323fe690b1b876b3324ced417 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 11:15:56 +0200 Subject: [PATCH 052/289] squash! ci : update ONEAPI version to 2025.3.3-0-devel-ubuntu24.04 --- .devops/main-intel.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index dbb60682dce..86b901c1538 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -16,7 +16,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ fi && \ make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}" -FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS runtime WORKDIR /app RUN apt-get update && \ From db7bcdb791162f6a256babc73e6cbceebe7c5d8d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 14 May 2026 05:27:13 +0200 Subject: [PATCH 053/289] Revert "ggml-cuda : add ar_add() to avoid ambiguous operator+ for half/bfloat16 in CUDA 11.8" This reverts commit 5cd228494af3973294e90aad95b58c2ede400f43. Reverting in favor of: https://github.com/ggml-org/llama.cpp/pull/22994 --- ggml/src/ggml-cuda/allreduce.cu | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 03d88968cd5..434689abd95 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -105,20 +105,6 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; // blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are // handled only by block 0 to avoid cross-block writes to the same slots. // --------------------------------------------------------------------------- - -// CUDA 11.8 does not expose operator+ for half/bfloat16 below sm_530, -// so use the explicit intrinsics to avoid ambiguous implicit conversions. -template -static __device__ inline T ar_add(T a, T b) { - if constexpr (std::is_same_v) { - return __hadd(a, b); - } else if constexpr (std::is_same_v) { - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); - } else { - return a + b; - } -} - template static __global__ void ggml_cuda_ar_kernel( const T_dst * sendbuf, @@ -198,13 +184,13 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(wire[k])); + recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); recvbuf[tail + tid] = - ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(host_other[tail + tid])); + ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); } } } @@ -224,7 +210,7 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(src[i])); + dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); } } From 5a24c7538fcf5ccc04770b03fe569a98cf1b0f5d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 14 May 2026 05:28:56 +0200 Subject: [PATCH 054/289] Revert "ggml-cuda : add explicit casts to -INFINITY for float and half2 types" This reverts commit a2839b4404de473bc7af127b7b308d530afda024. Reverting this as after closer inspection these only warnings and not errors. --- ggml/src/ggml-cuda/common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 246a76193ca..10817505d9f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -582,9 +582,9 @@ template struct block_reduce_policy { static __device__ T sentinel() { if constexpr (std::is_same_v) { - return -(float)INFINITY; + return -INFINITY; } else if constexpr (std::is_same_v) { - return make_half2(__float2half(-(float)INFINITY), __float2half(-(float)INFINITY)); + return make_half2(-INFINITY, -INFINITY); } else { static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); } From dd706793ccdca94b01c5e3a39b000bbccc552502 Mon Sep 17 00:00:00 2001 From: Steve Lhomme Date: Sun, 10 May 2026 16:35:38 +0200 Subject: [PATCH 055/289] ggml: install ggml.pc in /pkgconfig (ggml/1480) That's always how it's done: https://github.com/search?q=path%3ACMakeLists.txt%20%22%24%7BCMAKE_INSTALL_LIBDIR%7D%2Fpkgconfig%22&type=code --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 672b37dffc3..4e65cd68b4e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -352,7 +352,7 @@ if (GGML_STANDALONE) @ONLY) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc - DESTINATION share/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) endif() # From 5f08683bb615fac03383d171737281f65050fe82 Mon Sep 17 00:00:00 2001 From: CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Date: Sun, 10 May 2026 16:45:00 +0200 Subject: [PATCH 056/289] metal : tighten input-position loop in kernel_conv_transpose_1d (ggml/1477) For a given output position j on the time axis, only input positions i such that i*s0 <= j < i*s0 + K contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] intersected with [0, IL-1]. That's at most ceil(K/s0) values (typically 2 for stride==K/2 transposed convs). The current kernel iterates the full IL range and filters with an `if`, amplifying per-thread work by IL/ceil(K/s0) (~160x for IL=320, K=10, s0=5 -- a representative codec-decoder shape). On Apple M1 the wasted work trips the macOS GPU watchdog (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) on long graphs. Compute i_min, i_max analytically before the inner loop and iterate only [i_min, i_max]. Output is bit-identical (same multiplies and adds in the same order); loop bound shrinks by IL/ceil(K/s0). Tested on M1 with a downstream consumer running a TTS codec at full T_codec; end-to-end codec decode ~3-4x faster, zero watchdog hits across long synthesis runs vs ~30% pre-patch. --- ggml/src/ggml-metal/ggml-metal.metal | 31 +++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaedeae..5c2ec8a4ab8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4850,15 +4850,32 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { - float v = 0.0f; + // For output position j on the time axis, only input positions + // i such that i*s0 <= j < i*s0 + K + // contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] + // intersected with [0, IL-1]. That's at most ceil(K/s0) values + // (typically 2 for stride==K/2 transposed convs). + const int32_t j = tgpig[0]; + const int32_t s0 = args.s0; + const int32_t K = args.K; + const int32_t IL = args.IL; + + int32_t i_min; + { + int32_t a = j - K + 1; + i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0 + } + int32_t i_max = j / s0; + if (i_max > IL - 1) i_max = IL - 1; - for (int64_t c = 0; c < args.IC; c++) { - const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; - const int32_t input_offset = c * args.IL; + float v = 0.0f; + if (i_min <= i_max) { + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; - for (int64_t i = 0; i < args.IL; i++) { - if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { - v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; + for (int32_t i = i_min; i <= i_max; i++) { + v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i]; } } } From 73f63f529539a740a3a81f87ba4984abfa7daf3d Mon Sep 17 00:00:00 2001 From: Oliver Walsh Date: Sun, 10 May 2026 16:32:41 +0100 Subject: [PATCH 057/289] ggml-virtgpu : include missing mutex header (llama/22810) Add missing `#include ` in ggml-backend-device.cpp. Fixes: #22809 Signed-off-by: Oliver Walsh --- ggml/src/ggml-virtgpu/ggml-backend-device.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp index ec8156bb868..a978812cd90 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -1,5 +1,7 @@ #include "ggml-remoting.h" +#include + static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); From 4db2f450754b1aa41e45c3b1328555f48f2ba6fb Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 11 May 2026 13:01:47 +0800 Subject: [PATCH 058/289] Add OP im2col_3d (llama/22903) * add im2col_3d * format code * update the ops.md --- ggml/src/ggml-sycl/ggml-sycl.cpp | 9 + ggml/src/ggml-sycl/im2col.cpp | 442 ++++++++++++++++++++++++------- ggml/src/ggml-sycl/im2col.hpp | 8 +- 3 files changed, 367 insertions(+), 92 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e7768b8bf61..57cc4ffb6f7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4159,6 +4159,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) ggml_sycl_op_im2col(ctx, dst); } +static void ggml_sycl_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_im2col_3d(ctx, dst); +} + static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); @@ -4456,6 +4461,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_IM2COL: ggml_sycl_im2col(ctx, dst); break; + case GGML_OP_IM2COL_3D: + ggml_sycl_im2col_3d(ctx, dst); + break; case GGML_OP_POOL_2D: ggml_sycl_pool2d(ctx, dst); break; @@ -5175,6 +5183,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_UPSCALE: return true; case GGML_OP_SUM: diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 6d75d34d83f..7bf3584fb97 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2026 Intel Corporation // SPDX-License-Identifier: MIT // @@ -12,125 +12,389 @@ #include "im2col.hpp" -#include -#include // For std::is_same_v - -#include "ggml.h" +#define MAX_GRIDDIM_Z 65535 template -static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, - int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, - int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) { - const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); - - // make each work-item deal with more elements since sycl global range can not exceed max int - for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { - const int64_t ksize = OW * KH; - const int64_t kx = i / ksize; - const int64_t kd = kx * ksize; - const int64_t ky = (i - kd) / OW; - const int64_t ix = i % OW; - - const int64_t oh = item_ct1.get_group(1); - const int64_t batch = item_ct1.get_group(0) / IC; - const int64_t ic = item_ct1.get_group(0) % IC; - - const int64_t iiw = (ix * s0) + (kx * d0) - p0; - const int64_t iih = (oh * s1) + (ky * d1) - p1; - - const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx); - - const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset); - const int64_t offset_src = offset_src_base + (iih * IW) + iiw; - - const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW); - const float src_val = out_of_bounds ? 0.0f : x[offset_src]; - - if constexpr (std::is_same_v) { - dst[offset_dst] = sycl::half(src_val); - } else if constexpr (std::is_same_v) { - dst[offset_dst] = src_val; - } +static void im2col_kernel( + const float * x, T * dst, + int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, + int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW, + int s0, int s1, int p0, int p1, int d0, int d1) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (i >= IC_KH_KW) { + return; } -} -template -static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, - int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - const int64_t parallel_elements = OW * KW * KH; - const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t iic = i / (KH_KW); + const int64_t rem = i - iic * KH_KW; + const int64_t ikh = rem / KW; + const int64_t ikw = rem - ikh * KW; - // decrease global range when it exceeds the max int - int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + const int64_t iow = item_ct1.get_group(1); + for (int64_t iz = item_ct1.get_group(0); iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - sycl::range<3> block_nums(batch * IC, OH, num_blocks); - sycl::range<3> local_range(1, 1, local_size); + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t CHW = IC * KH * KW; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { - im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, - p0, p1, d0, d1, item_ct1); - }); + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } + } + + GGML_UNUSED(IC); + GGML_UNUSED(KH); } -static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, - int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, - int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - if (!stream->get_device().has(sycl::aspect::fp16)) { - throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported), - "Device does not support half precision (fp16) operations!"); - } - im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, - p1, d0, d1, stream); +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +template +static void im2col_sycl(const float * x, + T * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + const int64_t IC_KH_KW = IC * KH * KW; + const int64_t num_blocks = (IC_KH_KW + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t N_OH = N * OH; + const int64_t KH_KW = KW*KH; + dpct::dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + /* + DPCT1049:73: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * sycl::range<3>(1, 1, MIN(IC_KH_KW, SYCL_IM2COL_BLOCK_SIZE)), + sycl::range<3>(1, 1, MIN(IC_KH_KW, SYCL_IM2COL_BLOCK_SIZE))), + [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, + s0, s1, p0, p1, d0, d1); + }); +} + +static void im2col_sycl_f16(const float * x, + sycl::half * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + im2col_sycl(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } -static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0, - int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, - d0, d1, stream); +static void im2col_sycl_f32(const float * x, + float * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + im2col_sycl(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; - const int32_t s1 = ((const int32_t *) (dst->op_params))[1]; - const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; - const int32_t p1 = ((const int32_t *) (dst->op_params))[3]; - const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; - const int32_t d1 = ((const int32_t *) (dst->op_params))[5]; + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; + const int64_t IW = src1->ne[0]; const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; + const int64_t KW = src0->ne[0]; const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; + const int64_t OW = dst->ne[1]; + + const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t N = src1->ne[is_2D ? 3 : 2]; + const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + + if(dst->type == GGML_TYPE_F16) { + im2col_sycl_f16(src1_d, (sycl::half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, + d0, d1, stream); + } else { + im2col_sycl_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static void im2col_3d_kernel( + const float * src, T * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (i >= IC_KD_KH_KW) { + return; + } + GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH); + GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW); + + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t ikw = i % KW; + + const int64_t iow = item_ct1.get_group(1); + for (int64_t iz = item_ct1.get_group(0); iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; + + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static void im2col_3d_sycl(const float * src, + T * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t ID_IH_IW = ID*IH*IW; + const int64_t KH_KW = KH*KW; + const int64_t IH_IW = IH*IW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + const int64_t OW_KD_KH_KW = OW*KD*KH*KW; + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + dpct::dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + /* + DPCT1049:74: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * sycl::range<3>(1, 1, MIN(IC_KD_KH_KW, SYCL_IM2COL_BLOCK_SIZE)), + sycl::range<3>(1, 1, MIN(IC_KD_KH_KW, SYCL_IM2COL_BLOCK_SIZE))), + [=](sycl::nd_item<3> item_ct1) { + im2col_3d_kernel(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, + ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, + OD_OH_OW_IC_KD_KH_KW, OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + stride_q, stride_z, stride_y, stride_x, s0, s1, s2, p0, p1, p2, d0, d1, + d2); + }); +} + +static void im2col_3d_sycl_f16(const float * src, + sycl::half * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + im2col_3d_sycl(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, stride_q, stride_z, stride_y, + stride_x, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +static void im2col_3d_sycl_f32(const float * src, + float * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + im2col_3d_sycl(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +void ggml_sycl_op_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float); - const int64_t batch = src1->ne[is_2D ? 3 : 2]; - const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float); + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; - queue_ptr stream = ctx.stream(); + const size_t es = ggml_element_size(src1); + const int64_t stride_x = src1->nb[0] / es; + const int64_t stride_y = src1->nb[1] / es; + const int64_t stride_z = src1->nb[2] / es; + const int64_t stride_q = src1->nb[3] / es; - if (dst->type == GGML_TYPE_F16) { - im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, - batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + if(dst->type == GGML_TYPE_F16) { + im2col_3d_sycl_f16(src1_d, (sycl::half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } else { - im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, - batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + im2col_3d_sycl_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } } diff --git a/ggml/src/ggml-sycl/im2col.hpp b/ggml/src/ggml-sycl/im2col.hpp index dbbb248ddb4..976d1094636 100644 --- a/ggml/src/ggml-sycl/im2col.hpp +++ b/ggml/src/ggml-sycl/im2col.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2026 Intel Corporation // SPDX-License-Identifier: MIT // @@ -15,7 +15,9 @@ #include "common.hpp" -void ggml_sycl_op_im2col( - ggml_backend_sycl_context & ctx, ggml_tensor *dst); +#define SYCL_IM2COL_BLOCK_SIZE 256 + +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_op_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_IM2COL_HPP From 0077a6d3320dcf3d72983a0ce4f0ba35e21d051e Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 11 May 2026 12:16:38 +0200 Subject: [PATCH 059/289] CUDA: directly include cuda/iterator (llama/22936) Before, we relied on a transient import from `cub/cub.cuh`, which is bad practice to do as cub may not always expose cuda/iterator --- ggml/src/ggml-cuda/argsort.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 0f3f017b534..c4f08091e79 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -4,6 +4,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) # define STRIDED_ITERATOR_AVAILABLE +# include # endif using namespace cub; #endif // GGML_CUDA_USE_CUB From c0c1f994b711114e907ac3250605b7542d7d19ec Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 11 May 2026 05:49:03 -0500 Subject: [PATCH 060/289] vulkan: Support asymmetric FA in scalar/mmq/coopmat1 paths (llama/22589) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 253 +++++++-------- .../vulkan-shaders/flash_attn.comp | 176 +++++----- .../vulkan-shaders/flash_attn_base.glsl | 210 +++--------- .../vulkan-shaders/flash_attn_cm1.comp | 160 +++++----- .../vulkan-shaders/flash_attn_cm2.comp | 43 +-- .../vulkan-shaders/flash_attn_dequant.glsl | 123 +++++++ .../vulkan-shaders/flash_attn_mmq_funcs.glsl | 300 +++++++++++------- .../vulkan-shaders/mul_mmq_shmem_types.glsl | 11 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 36 +-- 9 files changed, 632 insertions(+), 680 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0a7931002ab..7e450a559dd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -855,7 +855,7 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; - std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + std::map pipeline_flash_attn_f32_f16; std::map, vk_pipeline> pipeline_fa_mask_opt; @@ -2933,10 +2933,10 @@ struct vk_fa_tuning_params { } }; -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type); +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); -static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { vk_fa_tuning_params result{}; result.path = FA_SCALAR; @@ -2988,7 +2988,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; - if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) { + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) { result.block_rows /= 2; } @@ -3011,10 +3011,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_rows); GGML_UNUSED(n_kv); - GGML_UNUSED(kv_type); + GGML_UNUSED(k_type); + GGML_UNUSED(v_type); GGML_UNUSED(f32acc); vk_fa_tuning_params result{}; @@ -3070,12 +3071,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device } static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { - // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. - if (k_type != v_type) { - GGML_ASSERT(device->coopmat2); - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - } - FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -3087,7 +3082,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -3107,9 +3102,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT2: return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: @@ -3279,6 +3274,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev return 0; // If no matching configuration is found } +// Whether scalar flash attention will use the MMQ path for the given k_type. +static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) { +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + return device->integer_dot_product && device->subgroup_clustered && + (k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 || + k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 || + k_type == GGML_TYPE_Q8_0); +#else + GGML_UNUSED(device); + GGML_UNUSED(k_type); + return false; +#endif +} + static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); @@ -3525,121 +3534,96 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; -#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - uint32_t fa_sgs = fa.first.subgroup_size; \ - bool fa_ds = fa.first.subgroup_size == 0; \ - if (path == FAPATH) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } \ - } \ - } - - if (device->fp16) { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8) - } else -#endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) - } - } else { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) - + // FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V + // quant type is selected at runtime via the FaTypeK / FaTypeV spec constants. + + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_SCALAR) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); + const void * spv_data = nullptr; + size_t spv_size = 0; + if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8) - } else + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } + else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_int8_data; + spv_size = flash_attn_f32_f16_fp32_int8_len; + } #endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + } else { + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + } } + const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); } + #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT1) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const void * spv_data; + size_t spv_size; + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); + } } #endif + #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -#define CREATE_FA_CM2_MIXED() \ - for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - if (path == FA_COOPMAT2) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } \ - } \ - } \ - } \ - } if (device->coopmat2) { - CREATE_FA_CM2_MIXED(); + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT2) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + + const void * spv_data; + size_t spv_size; + const char * name; + if (aligned) { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; } + } + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0); + } } -#undef CREATE_FA_CM2_MIXED #endif -#undef CREATE_FA const int mul_mat_id_param_count = 5; @@ -8940,8 +8924,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) { GGML_UNUSED(f32acc); + GGML_UNUSED(v_type); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; const uint32_t Br = params.block_rows; @@ -8949,10 +8934,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); - const bool mmq = device->integer_dot_product && device->subgroup_clustered && - (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 || - kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 || - kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL); + const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); @@ -8969,17 +8951,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // kvsh uses D = HSV (K goes through kblocksh instead) kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - // block_a_cache size depends on quant type - uint32_t block_a_size; - switch (kv_type) { - case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q8_0: - case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break; - default: block_a_size = 0; break; - } + // The mixed MMQ shader uses a superset block_a_cache that fits every + // FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm. + // Single-scale types leave dm.y unused; non-Q5_* leave qh unused. + const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size; kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; } else { Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; @@ -9117,10 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); - if (tuning_params.path != FA_COOPMAT2) { - GGML_ASSERT(k->type == v->type); - } - const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); @@ -9164,7 +9135,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { std::lock_guard guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { pipeline = it->second; @@ -15642,10 +15613,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // mismatching K/V type is currently supported for coopmat2 only. - if (op->src[1]->type != op->src[2]->type && !coopmat2) { - return false; - } auto fa_kv_ok = [coopmat2](ggml_type t) { switch (t) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6e6bdabc92e..6ac095489b3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -22,6 +22,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; @@ -128,18 +129,20 @@ void main() { Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); -#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); - } -#else // Q4_0, Q4_1, Q5_0, Q5_1 - const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; - const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + // Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores + // the row-sum scaled by qd, used in k_dot_correction. + if (FaTypeK == FA_TYPE_Q8_0) { + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } + } else { + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } } -#endif #endif } barrier(); @@ -177,13 +180,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -257,21 +256,21 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; } } #else // MMQ - const uint ints_per_block = 8 / QUANT_R_MMQ; + const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK); const uint quant_iters = Bc * HSK / 32 * ints_per_block; [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { const uint32_t iqs = (idx + tid) % ints_per_block; @@ -310,15 +309,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); @@ -335,15 +332,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); @@ -366,72 +361,47 @@ void main() { int32_t k_quants[d_per_step]; ACC_TYPEV2 k_dm; + // Q4_*/Q5_* take the block-8 fast path when one step covers a full + // block; Q8_0 always goes through the per-int get_k_qs* helpers + // (its qs is byte-packed, not nibble-packed). + const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0); + if (SHMEM_STAGING != 0) { const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0); -#else k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { + if (block8_fast) { + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); [[unroll]] for (uint32_t d = 0; d < 4; d++) { uint vui = kblocksh[buf_ib].qs[d]; k_quants[d ] = int32_t( vui & 0x0F0F0F0F); k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; - uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + if (has_qh) { + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); } } } else { - const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block); - const uint ib = coord / BLOCK_SIZE; - const uint iqs = (coord % BLOCK_SIZE); + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE_K; + const uint iqs = (coord % BLOCK_SIZE_K); -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0); -#else - k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset)); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { -#if defined(DATA_A_Q5_0) - uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0], - k_packed.k_data_packed16[k_offset + ib].qh[1])); -#elif defined(DATA_A_Q5_1) - uint qh = k_packed.k_data_packed16[k_offset + ib].qh; -#endif - [[unroll]] for (uint32_t d = 0; d < 4; d++) { -#if defined(A_TYPE_PACKED32) - uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d]; -#else - uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], - k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); -#endif - k_quants[d ] = int32_t( vui & 0x0F0F0F0F); - k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (qh >> (d * 4)) & 0xF; - uint qh_hi = (qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset)); + + if (block8_fast) { + fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset); + [[unroll]] for (uint32_t d = 0; d < 8; d++) { + k_quants[d] = blk.qs[d]; } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); } @@ -516,14 +486,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -547,15 +517,13 @@ void main() { FLOAT_TYPEV4 Vf; if (SHMEM_STAGING != 0) { Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_V) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else + } else { Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index efed3a73e22..9a7957da97b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -#if defined(DATA_A_F32) -layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; -#elif defined(A_TYPE_PACKED16) -layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; -#endif - -#if defined(A_TYPE_PACKED32) -layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32; -layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32; -#endif - -#ifndef BLOCK_SIZE -#define BLOCK_SIZE 1 -#endif - -#if defined(DATA_A_F32) -#undef BLOCK_SIZE -#define BLOCK_SIZE 4 -#define BLOCK_BYTE_SIZE 16 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - // iqs is currently always zero in the flash attention shaders - if (binding_idx == BINDING_IDX_K) { - return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); - } else { - return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); - } -} -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 -#elif defined(DATA_A_Q4_1) -#define BLOCK_BYTE_SIZE 20 -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif - } -} -#endif - -#if defined(DATA_A_Q5_0) -#define BLOCK_BYTE_SIZE 22 -#elif defined(DATA_A_Q5_1) -#define BLOCK_BYTE_SIZE 24 -#endif - -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = v_packed.v_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } -} -#endif - - -#if defined(DATA_A_IQ4_NL) -#define BLOCK_BYTE_SIZE 18 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); +// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the +// host can pass the type directly. Keep in sync with ggml.h. +#define FA_TYPE_F32 0u +#define FA_TYPE_F16 1u +#define FA_TYPE_Q4_0 2u +#define FA_TYPE_Q4_1 3u +#define FA_TYPE_Q5_0 6u +#define FA_TYPE_Q5_1 7u +#define FA_TYPE_Q8_0 8u +#define FA_TYPE_Q1_0 41u + +// Number of matrix elements per buffer block, derived from the K/V type spec +// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 +// and bypasses the dequant path entirely. Quants follow their ggml block sizes. +uint fa_block_elems(uint ty) { + switch (ty) { + case FA_TYPE_F32: return 4u; + case FA_TYPE_F16: return 1u; + case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere + default: return 1u; } } -#endif -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); - } else { - const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); +// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte +// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number +// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R. +uint fa_quant_r_mmq(uint ty) { + switch (ty) { + case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0); + default: return 1u; } } -#endif + +// These can't be `const` globals because GLSL forbids function calls in global +// const initializers, even when the spec constants would let the driver fold +// them. Macros expand at the use site and fold after specialization. +#define BLOCK_SIZE_K fa_block_elems(FaTypeK) +#define BLOCK_SIZE_V fa_block_elems(FaTypeV) +// F16 reads f16 elements directly from the binding; everything else routes +// through dequantize4 / the MMQ helpers to unpack from the packed block layout. +#define USE_DECODE_K (FaTypeK != FA_TYPE_F16) +#define USE_DECODE_V (FaTypeV != FA_TYPE_F16) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 526e8da384e..bffcc095be3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -14,6 +14,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -127,13 +128,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -227,14 +224,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { f16vec4 K_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; @@ -256,47 +253,40 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If K is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) { -#endif - barrier(); - [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { - uint32_t col_vec = (idx + tid) % (MatBr / 4); - uint32_t row = (idx + tid) / (MatBr / 4); - if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); - if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); -#endif - } + // For quants we always need to dequant into kvsh; for f16 we can load + // directly from global memory when alignment / bounds allow it. + const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; + if (stage_k) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + f16vec4 K_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { + if (USE_DECODE_K) { + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } + } - kvsh[row * kvsh_stride + col_vec] = K_Tf; + kvsh[row * kvsh_stride + col_vec] = K_Tf; + } } + barrier(); } - barrier(); -#if BLOCK_SIZE == 1 - } -#endif -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) -#endif - { + if (stage_k) { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); - } -#if BLOCK_SIZE == 1 - else { + } else { const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); } -#endif } else { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); @@ -397,14 +387,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { f16vec4 V_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -431,36 +421,33 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If V is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - // For f16, only preload if not aligned - if (KV_bounds_check) { -#endif - [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { - const uint idx = i * gl_WorkGroupSize.x + tid; - const uint row = idx / v_cols; - const uint col = idx % v_cols; - - const uint v_row = j * Bc + row; - const uint v_col = hsv_tile * MatBc * row_split + col * 4; - - const uint coord = v_row * v_stride * BLOCK_SIZE + v_col; - const uint ib = coord / BLOCK_SIZE; - const uint iqs = coord % BLOCK_SIZE; - - if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { -#if BLOCK_SIZE > 1 - kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; -#endif - } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + // For quants we always preload via kvsh. For f16 we only preload when + // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). + const bool stage_v = USE_DECODE_V || KV_bounds_check; + if (stage_v) { + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; + + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; + + const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col; + const uint ib = coord / BLOCK_SIZE_V; + const uint iqs = coord % BLOCK_SIZE_V; + + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { + if (USE_DECODE_V) { + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + } + } else { + kvsh[row * vsh_stride + col] = f16vec4(0.0f); + } } } - -#if BLOCK_SIZE == 1 - } -#endif } barrier(); @@ -471,15 +458,12 @@ void main() { coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (!KV_bounds_check) { + if (!USE_DECODE_V && !KV_bounds_check) { // F16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); - } else -#endif - { + } else { const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 8a7bbaeb92c..141bb870883 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; -uint fa_block_elems(uint ty) { - switch (ty) { - case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) - case 1u: return 1u; // GGML_TYPE_F16 - case 2u: return uint(QUANT_K_Q4_0); - case 3u: return uint(QUANT_K_Q4_1); - case 6u: return uint(QUANT_K_Q5_0); - case 7u: return uint(QUANT_K_Q5_1); - case 8u: return uint(QUANT_K_Q8_0); - case 41u: return uint(QUANT_K_Q1_0); - default: - return 1u; - } -} - float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeV) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl new file mode 100644 index 00000000000..02106f33cbe --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -0,0 +1,123 @@ +// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V) +// covering every supported FA element type, plus an uber dequantize4() that +// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver +// folds away every path except the one matching the K/V type for this pipeline. +// +// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by +// flash_attn_cm2.comp, which has its own buffer_reference-based decode path. +// +// We use macros (rather than per-quant decode functions taking a struct) on +// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16 +// when FLOAT16 isn't defined, which makes float16-containing struct values +// illegal to return from / pass to functions. Macros expand inline where the +// float16 stays in storage and is converted to FLOAT_TYPE at use. + +// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl +// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32. +layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32; +layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32; + +layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0; +layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0; +layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1; +layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1; +layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0; +layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0; +layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1; +layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1; +layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; +layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; + +// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 +// views, used by the MMQ K-side hot path for fast 4-uint loads. +layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; +layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32; + +// Per-quant decode bodies are expanded once for the K view set and once for +// the V view set. The macros take the buffer name as a parameter. +#define FA_DEQUANT4_F32(BUF) \ + return FLOAT_TYPEV4(BUF.data[a_offset + ib]); + +#define FA_DEQUANT4_Q4_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \ +} + +#define FA_DEQUANT4_Q4_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q5_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = uint(BUF.data[a_offset + ib].qh[0]) \ + | (uint(BUF.data[a_offset + ib].qh[1]) << 16); \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \ +} + +#define FA_DEQUANT4_Q5_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = BUF.data[a_offset + ib].qh; \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q8_0(BUF) { \ + const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \ + const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ +} + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + switch (FaTypeK) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + } + } else { + switch (FaTypeV) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + } + } + return FLOAT_TYPEV4(0); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl index e14e62d546a..6bf10a7cffd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -1,149 +1,203 @@ -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q4_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - return int32_t(vui & 0x0F0F0F0F); -} -#endif - -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q5_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0], - k_packed.k_data_packed16[a_offset + ib].qh[1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - uint qh_bits = (qh >> iqs) & 0xF; - return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); -} -#endif - -#if defined(DATA_A_Q8_0) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])); -} -#endif +// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and +// reads from the matching aliased K binding declared in flash_attn_dequant.glsl. +// Spec-constant specialization folds the unused paths. -#if defined(DATA_A_IQ4_NL) int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - u8vec4 idx = unpack8(vui & 0x0F0F0F0F); - return pack32(i8vec4(kvalues_iq4nl_const[idx.x], - kvalues_iq4nl_const[idx.y], - kvalues_iq4nl_const[idx.z], - kvalues_iq4nl_const[idx.w])); + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q4_1: { // uses packed32 alias + uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q5_0: { + uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16 + uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed_q5_1.data[a_offset + ib].qh; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2], + k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1])); + } + default: return 0; + } } -#endif -#if QUANT_AUXF == 1 -FLOAT_TYPE get_k_d(uint ib, uint a_offset) { - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d); -} -#else -FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) { - return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm); +// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0) +// return (d, 0) so call sites always see the same shape. +FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0); + default: return FLOAT_TYPEV2(0); + } } -#endif void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { -#if defined(DATA_A_Q4_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_Q4_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; -#elif defined(DATA_A_Q5_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - if (iqs == 0) { - kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0], - k_packed.k_data_packed16[a_offset + global_ib].qh[1])); + // kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need + // explicit casts. The bit pattern is what we care about here -- the actual + // signed/unsigned interpretation happens downstream in the dot product. + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + break; + } + case FA_TYPE_Q4_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]); + break; + } + case FA_TYPE_Q5_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0], + k_packed_q5_0.data[a_offset + global_ib].qh[1])); + } + break; + } + case FA_TYPE_Q5_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]); + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh; + } + break; + } + case FA_TYPE_Q8_0: { + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1])); + break; + } } -#elif defined(DATA_A_Q5_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; + if (iqs == 0) { - kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh; + // Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair. + switch (FaTypeK) { + case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break; + } } -#elif defined(DATA_A_Q8_0) - kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_IQ4_NL) - const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); - const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); - kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y], - kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w])); - kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y], - kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w])); -#endif +} - if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d); -#else - kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm); -#endif +// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed +// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads +// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble +// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned). +// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't +// share this nibble shape. +// +// Returned via a struct so the caller's k_quants array (sized from spec +// constants) doesn't need to match a fixed[8] out-parameter type. +struct fa_k_qs_block8 { + int32_t qs[8]; +}; + +fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) { + fa_k_qs_block8 r; + uint qh = 0; + if (FaTypeK == FA_TYPE_Q5_0) { + qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + } else if (FaTypeK == FA_TYPE_Q5_1) { + qh = k_packed_q5_1.data[a_offset + ib].qh; } + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = 0; + switch (FaTypeK) { + case FA_TYPE_Q4_0: { // packed16 + vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q4_1: { // packed32 alias + vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d]; + break; + } + case FA_TYPE_Q5_0: { // packed16 + vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q5_1: { // packed32 alias + vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d]; + break; + } + } + r.qs[d ] = int32_t( vui & 0x0F0F0F0F); + r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); + if (has_qh) { + uint qh_lo = (qh >> (d * 4)) & 0xFu; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu; + r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } + } + return r; } int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); - uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF; - return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); -#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - return kblocksh[buf_ib].qs[pos]; -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: + case FA_TYPE_Q4_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + } + case FA_TYPE_Q5_0: + case FA_TYPE_Q5_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return kblocksh[buf_ib].qs[pos]; + } + default: return 0; + } } ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { -#if defined(DATA_A_Q4_0) - return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q5_0) - return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) - return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; -#else - return ACC_TYPE(0.0); -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q4_1: + case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; + default: return ACC_TYPE(0.0); + } } void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { kblocksh[buf_ib].qs[iqs] = 0; -#if defined(DATA_A_IQ4_NL) - kblocksh[buf_ib].qs[iqs + 4] = 0; -#endif if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f); -#else kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); -#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 10552d013a2..79c933f40cf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -1,4 +1,13 @@ -#if defined(DATA_A_Q4_0) +#if defined(FA_MMQ_MIXED) +// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0. +// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale +// types (Q4_0/Q5_0/Q8_0) leave dm.y unused. +struct block_a_cache { + int32_t qs[8]; + uint32_t qh; + FLOAT_TYPEV2 dm; +}; +#elif defined(DATA_A_Q4_0) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 6f2a929c40c..d99b2b5d802 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -643,42 +643,22 @@ void process_shaders() { if (fp16) { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); #endif - } - - for (const auto& tname : type_names) { - if (tname == "bf16") continue; - if (fp16) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } + string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); #endif - } + } - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (tname != "f32") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); - } + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); #endif - } - } } } From 449b33fc8f6aadf267e3b577622deae45c81ea0c Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 11 May 2026 18:42:08 +0200 Subject: [PATCH 061/289] Ggml/cuda snake fusion hardening (llama/22912) * cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b92a208705d..e25be3592fd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph // closure check: the trailing add must read the same x as the leading mul const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; - const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + // Kernel iterates over total = T * C, so x and add must be 2D and + // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled. + const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) && + (add->ne[2] == 1 && add->ne[3] == 1) && + (a->ne[2] == 1 && a->ne[3] == 1); const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; - if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + // x must be in the supported whitelist and every operand / intermediate + // result must share x's type, since launch_snake casts a / inv_b as + // float and templates the kernel on a single T. Mixed precision chains + // fall back to the naive path. + const ggml_tensor * sin1 = cgraph->nodes[i + 1]; + const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) && + (a->type == x->type) && (inv_b->type == x->type) && + (mul0->type == x->type) && (sin1->type == x->type) && + (sqr->type == x->type) && (mul1->type == x->type) && + (add->type == x->type); + + if (types_ok && shape_ok && dim_ok && x_in_add == x) { ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); return 4; } @@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 From 287f637fb15f3c59d0f579f86c4b59d3ebfbc443 Mon Sep 17 00:00:00 2001 From: CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Date: Mon, 11 May 2026 19:48:29 +0200 Subject: [PATCH 062/289] CUDA: handle OW > 65535 in im2col (2D and 3D) (llama/22944) `im2col_cuda` and `im2col_3d_cuda` both dispatch with `block_nums.y = OW`. CUDA caps grid Y at 65535. Conv1d encoders on raw 16 kHz audio with T > 65535 (~ 4 s) trip the limit -- e.g. SEANet at 11 s lands at OW = 176000 -- and the launch returns `invalid configuration argument`. Clamp `block_nums.y` to `MIN(OW, MAX_GRIDDIM_Y)` and loop inside the kernel with stride `MAX_GRIDDIM_Y`. Same in-kernel stride pattern already used for the z axis (`MAX_GRIDDIM_Z`). Both 2D `im2col_kernel` and 3D `im2col_3d_kernel` need the same fix. Bit-identical for OW <= 65535 (single iteration of the new outer loop). Tested on T4 / Jetson Orin with a SEANet encoder running on 11 s / 16 kHz audio (im2col reaching OW ~ 176000); pre-fix launch returns `invalid configuration argument`, post-fix runs to completion. Existing test-backend-ops im2col cases unchanged. --- ggml/src/ggml-cuda/im2col.cu | 61 +++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 56dc0545742..28c79ab462e 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -1,5 +1,6 @@ #include "im2col.cuh" +#define MAX_GRIDDIM_Y 65535 #define MAX_GRIDDIM_Z 65535 template @@ -18,22 +19,23 @@ static __global__ void im2col_kernel( const int64_t ikh = rem / KW; const int64_t ikw = rem - ikh * KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OH; - const int64_t ioh = iz - in * OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t offset_dst = - ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } } } @@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst, const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; const int64_t N_OH = N * OH; const int64_t KH_KW = KW*KH; - dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z)); im2col_kernel<<>>(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, s0, s1, p0, p1, d0, d1); @@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel( const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; const int64_t ikw = i % KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OD_OH; - const int64_t iod = (iz - in*OD_OH) / OH; - const int64_t ioh = iz % OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t iid = iod * s2 + ikd * d2 - p2; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); - dst[offset_dst] = src[offset_src]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } } } } @@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst, const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z)); im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, From ea4652c42704fc298605fb639dedc40937845c78 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Mon, 11 May 2026 11:57:26 -0700 Subject: [PATCH 063/289] opencl: add q4_1 MoE for Adreno (llama/22856) * Q4_1 MoE CLC pass sanity check * remove unnecessary code * opencl: remove unnecessary asserts and reformat * opencl: fix supports_op for q4_1 moe * q4_1 moe is supported by Adreno with certain shapes --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 366 ++++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 90 +++++ .../kernels/gemm_moe_q4_1_f32_ns.cl | 254 ++++++++++++ .../kernels/gemv_moe_q4_1_f32_ns.cl | 119 ++++++ 5 files changed, 798 insertions(+), 33 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ffde6a4f063..7edb3eb4e9c 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -104,6 +104,8 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat gemm_moe_q4_0_f32_ns gemv_moe_q4_0_f32_ns + gemm_moe_q4_1_f32_ns + gemv_moe_q4_1_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4e6f6fb43d2..73a58f74a94 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -544,6 +544,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; + cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -602,6 +603,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; + cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -958,6 +960,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -2856,6 +2860,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -cl-mad-enable " " -cl-fast-relaxed-math"; + // gemv_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3749,11 +3785,14 @@ struct ggml_tensor_extra_cl_q4_1 { CL_CHECK(clReleaseMemObject(m)); m = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } // Currently, q_img and d_img are only initialized when SMALL_ALLOC is // enabled. They point to the images in ggml_backend_opencl_buffer_context. // So, there is no need to release them here. // TODO: initialize them for non SMALL_PATH path, or remove them. - q_img = nullptr; d_img = nullptr; m_img = nullptr; size_q = 0; @@ -4189,6 +4228,35 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm return GGML_STATUS_SUCCESS; } +// The optimized gemm and gemv kernels are used for large matrices without batch. +// tensor is the quantized weights matrix. +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && + tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + GGML_UNUSED(backend_ctx); + int ne01 = tensor->ne[1]; + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); +} + +inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + + bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); + + size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + + return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; @@ -4385,6 +4453,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } } + // q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support, + // the quantizations here currently do not - they are only supported by Adreno with certain shapes + if (op->src[0]->type == GGML_TYPE_Q4_1) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (op->src[1]->type == GGML_TYPE_F32) { + return use_adreno_moe_kernels(backend_ctx, op->src[0]) + && ggml_is_contiguous(op->src[0]) + && ggml_is_contiguous(op->src[1]); + } +#endif + return false; + } return false; case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -4555,6 +4635,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1) { + delete e; + } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + delete e; + } for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { delete e; } @@ -4868,35 +4954,6 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff return GGML_STATUS_SUCCESS; } -// The optimized gemm and gemv kernels are used for large matrices without batch. -// tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - int64_t threshold_ne0 = 512; - int64_t threshold_ne1 = 512; - if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && - backend_ctx->adreno_cl_compiler_version.type != DX) { - threshold_ne0 = 128; - threshold_ne1 = 128; - } - return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && - tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - GGML_UNUSED(backend_ctx); - int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); -} - -inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - - bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); - - size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; - - return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 -} - static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); @@ -5097,15 +5154,54 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // normal q4_1 repack +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -5862,6 +5958,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static ggml_cl_buffer buf_trans_q; static ggml_cl_buffer buf_trans_m; @@ -12862,6 +12988,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13131,6 +13258,179 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, break; } + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_1_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns; + + if (strstr(src0->name, "as") != NULL) { + moe_router_reoerder(backend, src2, ne20); + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32_flat; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c87450dc49e..5bbf09710f9 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -370,6 +370,96 @@ kernel void kernel_restore_block_q4_1_noshuffle( } } +kernel void kernel_convert_block_q4_1_trans4_ns( + __global struct block_q4_1 * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_m, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q4_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q4_1_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_m, + __global struct block_q4_1 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_dm_offset]; + b->m = src_m[src_dm_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..e2574ae0187 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_1(q4, a_f16, scale, m) \ + a_f16.s0 = (half)(q4.s0 & 0x000F) * scale + m; \ + a_f16.s1 = (half)((q4.s0 & 0x00F0) >> 4) * scale + m; \ + a_f16.s2 = (half)((q4.s0 & 0x0F00) >> 8) * scale + m; \ + a_f16.s3 = (half)((q4.s0 & 0xF000) >> 12) * scale + m; \ + a_f16.s4 = (half)(q4.s1 & 0x000F) * scale + m; \ + a_f16.s5 = (half)((q4.s1 & 0x00F0) >> 4) * scale + m; \ + a_f16.s6 = (half)((q4.s1 & 0x0F00) >> 8) * scale + m; \ + a_f16.s7 = (half)((q4.s1 & 0xF000) >> 12) * scale + m; \ + a_f16.s8 = (half)(q4.s2 & 0x000F) * scale + m; \ + a_f16.s9 = (half)((q4.s2 & 0x00F0) >> 4) * scale + m; \ + a_f16.sa = (half)((q4.s2 & 0x0F00) >> 8) * scale + m; \ + a_f16.sb = (half)((q4.s2 & 0xF000) >> 12) * scale + m; \ + a_f16.sc = (half)(q4.s3 & 0x000F) * scale + m; \ + a_f16.sd = (half)((q4.s3 & 0x00F0) >> 4) * scale + m; \ + a_f16.se = (half)((q4.s3 & 0x0F00) >> 8) * scale + m; \ + a_f16.sf = (half)((q4.s3 & 0xF000) >> 12) * scale + m; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_1_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q4_1 block + uint sm_offset = s_sub_offset + get_global_id(0); + half s = src0_d[sm_offset]; + half m = src0_m[sm_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..3739a215705 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -0,0 +1,119 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_1_to_fp32_packed8(ushort2 q4x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) * s + m); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) * s + m); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) * s + m); + fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) * s + m); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) * s + m); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) * s + m); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) * s + m); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_1_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s0), regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s1), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s2), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s3), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From 8ec91c91e17f043dc920439b0ef2b428b037614d Mon Sep 17 00:00:00 2001 From: guyfischman <138163913+guyfischman@users.noreply.github.com> Date: Tue, 12 May 2026 07:15:02 +0200 Subject: [PATCH 064/289] metal : promote mul_mv/mul_mm batch divisors to function constants (llama/22711) * metal : promote mul_mv/mul_mm batch divisors to function constants * metal : take op directly in get_pipeline_mul_mv_ext --- ggml/src/ggml-metal/ggml-metal-device.cpp | 46 +++++- ggml/src/ggml-metal/ggml-metal-device.h | 2 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 165 +++++++++++----------- 4 files changed, 127 insertions(+), 88 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d211bf79f14..f0147af84c1 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m return res; } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + const int ne12 = op->src[1]->ne[2]; + const int r2 = ne12 / op->src[0]->ne[2]; + const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3]; + + GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX); + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); - snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); + GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX); + const int16_t ne12 = (int16_t) op->src[1]->ne[2]; + const int16_t ne13 = (int16_t) op->src[1]->ne[3]; + const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]); + const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]); + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d", + base, bc_inp, bc_out, ne12, ne13, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2); + ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; + GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX); + const int16_t r2 = (int16_t) (ne12 / ne02); + const int16_t r3 = (int16_t) (ne13 / ne03); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); - snprintf(name, 256, "%s_nsg=%d", base, nsg); + snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4718ca083b0..1f212a92f98 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c875c..a114391c2e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 5c2ec8a4ab8..2d45de8cce2 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write( constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; +constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]]; +constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]]; +constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]]; template void mul_vec_q_n_f32_impl( @@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); @@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[NR0]; FOR_UNROLL (int row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; @@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl( device const block_q1_0 * ax[nr0]; for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); } @@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); @@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const T0 * x = (device const T0 *) (src0 + offset0); @@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl( // pointers to src0 rows device const T0 * ax [NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const T0 *) ((device char *) src0 + offset0); } @@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl( device const T0 * ax [NR0]; device const T04 * ax4[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax [row] = (device const T0 *) ((device char *) src0 + offset0); ax4[row] = (device const T04 *) ((device char *) src0 + offset0); @@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl( return; } - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -7479,10 +7482,10 @@ void kernel_mul_mv_q2_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); @@ -7584,10 +7587,10 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); @@ -7758,10 +7761,10 @@ void kernel_mul_mv_q4_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); @@ -7870,10 +7873,10 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); @@ -8006,10 +8009,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); @@ -8111,10 +8114,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); @@ -8219,10 +8222,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); @@ -8338,10 +8341,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); @@ -8450,10 +8453,10 @@ void kernel_mul_mv_iq3_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); @@ -8562,10 +8565,10 @@ void kernel_mul_mv_iq2_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); @@ -8675,10 +8678,10 @@ void kernel_mul_mv_iq1_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); @@ -8774,10 +8777,10 @@ void kernel_mul_mv_iq1_m_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); @@ -8883,10 +8886,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); @@ -8992,10 +8995,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); @@ -9103,10 +9106,10 @@ void kernel_mul_mv_mxfp4_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); @@ -9321,6 +9324,10 @@ kernel void kernel_diag_f32( constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]]; +constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]]; +constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]]; +constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]]; // each block_q contains 16*nl weights #ifdef GGML_METAL_HAS_TENSOR @@ -9347,11 +9354,11 @@ kernel void kernel_mul_mm( // Batch dimension handling const int im = tgpig.z; - const int i12 = im % args.ne12; - const int i13 = im / args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; // Batch offsets for srcA and srcB - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; // Tile dimensions constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; @@ -9490,10 +9497,10 @@ kernel void kernel_mul_mm( short il = il0; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; const short offset1 = il0/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; From 20895abdbd5f076118be9f8ddd0170448a399e79 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 12 May 2026 04:41:58 -0500 Subject: [PATCH 065/289] vulkan: Check shared memory size for mmq shaders (llama/22693) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 168 ++++++++++++++++++++++++--- 1 file changed, 149 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7e450a559dd..90ea7cc1a9b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -681,6 +681,15 @@ struct vk_device_struct { bool mul_mat_id_m[GGML_TYPE_COUNT]; bool mul_mat_id_s[GGML_TYPE_COUNT]; + // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses + // a different shared-memory layout than the float matmul shaders. + bool mul_mat_l_int[GGML_TYPE_COUNT]; + bool mul_mat_m_int[GGML_TYPE_COUNT]; + bool mul_mat_s_int[GGML_TYPE_COUNT]; + bool mul_mat_id_l_int[GGML_TYPE_COUNT]; + bool mul_mat_id_m_int[GGML_TYPE_COUNT]; + bool mul_mat_id_s_int[GGML_TYPE_COUNT]; + vk::DescriptorSetLayout dsl; vk_matmul_pipeline pipeline_matmul_f32 {}; @@ -3207,6 +3216,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec return supported; } +// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses +// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather +// than the float load buffers checked by ggml_vk_matmul_shmem_support. +// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline. +static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float. + const uint32_t fp_size = device->fp16 ? 2u : 4u; + const uint32_t fp_align = fp_size; + const uint32_t fp2_size = 2u * fp_size; + const uint32_t fp2_align = device->fp16 ? 4u : 8u; + + struct member { uint32_t size, align; }; + auto std430_size = [](std::initializer_list members) { + uint32_t off = 0, struct_align = 1; + for (const auto &m : members) { + off = (off + m.align - 1) & ~(m.align - 1); + off += m.size; + struct_align = std::max(struct_align, m.align); + } + return (off + struct_align - 1) & ~(struct_align - 1); + }; + + uint32_t block_a_size = 0; + switch (src0_type) { + case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm + case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2) + case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm + case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2) + case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm + case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d + case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2) + case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2) + case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2) + case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2) + case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2) + default: + return false; + } + + // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; } + const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); + + const uint32_t BM = warptile[1]; + const uint32_t BN = warptile[2]; + // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise. + const uint32_t BK_STEP = mul_mat_id ? 1u : 4u; + + const uint32_t buf_a_size = BM * BK_STEP * block_a_size; + const uint32_t buf_b_size = BN * BK_STEP * block_b_size; + const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u; + + const uint32_t warps = warptile[0] / warptile[10]; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u; + + const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported); + + return supported; +} + struct GpuPipelineConfig { // GPU architecture identifier. // Example: vk_device_architecture::AMD_GCN @@ -3453,6 +3526,40 @@ static void ggml_vk_load_shaders(vk_device& device) { } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } + + // The q8_1 mmq path has its own (larger) shmem layout, check it separately. + // K-quants use the _int_k warptiles, others use _int. + const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K || + t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K || + t == GGML_TYPE_Q6_K); + const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int; + const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int; + const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int; + const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int; + const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int; + const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int; + + if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) { + device->mul_mat_s_int[i] = false; + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) { + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) { + device->mul_mat_l_int[i] = false; + } + + if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) { + device->mul_mat_id_s_int[i] = false; + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) { + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) { + device->mul_mat_id_l_int[i] = false; + } } } @@ -5613,6 +5720,13 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_s[i] = true; break; } + + device->mul_mat_l_int[i] = true; + device->mul_mat_m_int[i] = true; + device->mul_mat_s_int[i] = true; + device->mul_mat_id_l_int[i] = true; + device->mul_mat_id_m_int[i] = true; + device->mul_mat_id_s_int[i] = true; } @@ -7220,6 +7334,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type]; + if (ctx->device->coopmat2) { const uint32_t shader_core_count = ctx->device->shader_core_count; const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); @@ -7236,26 +7357,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, // split_k==3 with large tiles likely better than medium tiles with no split_k. (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; - - GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -7312,35 +7431,42 @@ static void ggml_vk_matmul( ctx->prealloc_split_k_need_sync = true; } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type]; if (ctx->device->coopmat2) { // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul_id( @@ -7636,10 +7762,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -8471,10 +8599,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); From be5a35cceebee8b70c75ced22336b4ae4a8882af Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Tue, 12 May 2026 03:15:34 -0700 Subject: [PATCH 066/289] vulkan: Fix Windows performance regression on Intel GPU BF16 workloads for Xe2 and newer (llama/22461) * refactor * Use l_warptile only when coopamt is available for BF16 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 90ea7cc1a9b..a0a556206d5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4260,11 +4260,6 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; - if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { - // Xe2/Xe3 - bf16 warptile performance tuning - l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; - } - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -5689,19 +5684,19 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; - case VK_VENDOR_ID_INTEL: - if (!device->coopmat_support || device->architecture != INTEL_XE2) { - device->mul_mat_l[i] = false; - device->mul_mat_id_l[i] = false; - } else { - device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel - device->mul_mat_id_l[i] = true; - } + case VK_VENDOR_ID_INTEL: { + // Current Windows driver does not expose BF16 support. + // We only want to use l_warptile if coopmat is available and is Xe2+ + const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2; + const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat; + device->mul_mat_l[i] = use_l_warptile; + device->mul_mat_id_l[i] = use_l_warptile; device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; + } case VK_VENDOR_ID_APPLE: device->mul_mat_l[i] = false; device->mul_mat_m[i] = true; From a9bcbf559577c0a637819a704ad36091f0953fec Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Tue, 12 May 2026 10:27:04 -0400 Subject: [PATCH 067/289] ggml-webgpu: address precision issues for multimodal (llama/22808) * fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 193 ++++++++++++------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 30 ++- .../wgsl-shaders/flash_attn_tile.wgsl | 87 +++++--- .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 10 +- .../wgsl-shaders/flash_attn_vec_split.wgsl | 112 +++++----- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 49 +++-- 6 files changed, 295 insertions(+), 186 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c6dc2c21147..932a01d385e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -91,6 +91,7 @@ struct ggml_webgpu_shader_lib_context { uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; }; @@ -531,7 +532,9 @@ enum ggml_webgpu_flash_attn_path : uint32_t { }; struct ggml_webgpu_flash_attn_pipeline_key { + ggml_type q_type; ggml_type kv_type; + ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool kv_direct; @@ -542,16 +545,19 @@ struct ggml_webgpu_flash_attn_pipeline_key { uint32_t path; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { - return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && - kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && - has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path; + return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && + kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap && path == other.path; } }; struct ggml_webgpu_flash_attn_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_type); ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.kv_direct); @@ -595,14 +601,14 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_ } inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context, - uint32_t path) { + const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_decisions & decisions) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; bool kv_direct = false; - if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { kv_direct_align = context.sg_mat_k; } kv_direct = (context.src1->type == GGML_TYPE_F16) && @@ -611,7 +617,9 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ } ggml_webgpu_flash_attn_pipeline_key key = {}; + key.q_type = context.src0->type; key.kv_type = context.src1->type; + key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; key.kv_direct = kv_direct; @@ -619,13 +627,14 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.path = path; + key.path = decisions.path; return key; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { - uint32_t head_dim_v; - uint32_t wg_size; + uint32_t head_dim_v; + uint32_t wg_size; + ggml_type dst_type; }; struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { @@ -633,13 +642,14 @@ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.wg_size); + ggml_webgpu_hash_combine(seed, key.dst_type); return seed; } }; inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { - return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type; } struct ggml_webgpu_flash_attn_blk_pipeline_key { @@ -662,19 +672,32 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t head_dim_qk, uint32_t head_dim_v, bool has_mask, - bool kv_direct) { + bool kv_direct, + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t f16_elems = 0; size_t f32_elems = 0; - f16_elems += q_tile * head_dim_qk; // q_shmem + if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + f32_elems += head_dim_qk; // q_shmem + if (!kv_direct) { + f32_elems += kv_tile * max_head_dim; // kv_shmem + } + f32_elems += head_dim_v; // o_shmem + if (has_mask) { + f32_elems += kv_tile; // mask_shmem + } + f32_elems += kv_tile; // inter_shmem + return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; + } + f32_elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { - f16_elems += kv_tile * max_head_dim; // kv_shmem + f32_elems += kv_tile * max_head_dim; // kv_shmem } - f16_elems += q_tile * head_dim_v; // o_shmem + f32_elems += q_tile * head_dim_v; // o_shmem if (has_mask) { - f16_elems += q_tile * kv_tile; // mask_shmem + f32_elems += q_tile * kv_tile; // mask_shmem } - f16_elems += q_tile * kv_tile; // inter_shmem + f32_elems += q_tile * kv_tile; // inter_shmem f32_elems += q_tile; // row_max_shmem f32_elems += q_tile; // exp_sum_shmem return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; @@ -684,27 +707,27 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const ggml_webgpu_flash_attn_pipeline_key & key) { const size_t limit_bytes = context.wg_mem_limit_bytes; uint32_t q_tile = context.sg_mat_m; - uint32_t kv_granularity = context.sg_mat_n; + uint32_t kv_granularity = std::max(1u, context.sg_mat_n); if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; - kv_granularity = std::max(1u, context.max_subgroup_size); + kv_granularity = 1u; } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { q_tile = 1u; kv_granularity = 8u; } - const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!key.kv_direct) { - bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v, + key.has_mask, key.kv_direct, key.path); + if (limit_bytes <= base_q_bytes) { + return 0; } - if (key.has_mask) { - bytes_per_kv += q_tile; + const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v, + key.has_mask, key.kv_direct, key.path); + const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; + if (bytes_per_kv == 0) { + return 0; } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / kv_granularity) * kv_granularity; + const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( @@ -731,14 +754,18 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); + const bool tile_can_dispatch_all_q_rows = + context.max_subgroup_size > 0 && + context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + tile_can_dispatch_all_q_rows && !use_vec; decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : @@ -749,7 +776,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( return decisions; } - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); decisions.kv_direct = key.kv_direct; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); // invalidate if even the smallest kv_tile doesn't fit in shared memory @@ -778,21 +805,20 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( std::min(64u, max_kv_tile) : std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : + std::min(std::max(1u, context.max_wg_size), + std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) : std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size); - decisions.kv_tile = - std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity); + if (decisions.kv_tile == 0) { + return decisions; } if (decisions.kv_direct) { GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::max(1u, context.max_subgroup_size) : - context.sg_mat_n; + decisions.kv_tile -= + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n; } } return decisions; @@ -1577,7 +1603,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1694,10 +1720,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1805,13 +1831,13 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_subgroup_matrix = context.supports_subgroup_matrix; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -2074,10 +2100,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2194,10 +2220,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_vec_pipelines.find(key); if (it != mul_mat_id_vec_pipelines.end()) { @@ -2558,7 +2584,7 @@ class ggml_webgpu_shader_lib { const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); - ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; @@ -2586,6 +2612,30 @@ class ggml_webgpu_shader_lib { } variant += std::string("_") + ggml_type_name(key.kv_type); + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; + default: + GGML_ABORT("Unsupported Q type for flash attention shader"); + } + variant += std::string("_q") + ggml_type_name(key.q_type); + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + if (key.has_mask) { defines.push_back("MASK"); if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { @@ -2625,9 +2675,11 @@ class ggml_webgpu_shader_lib { shader_src = wgsl_flash_attn_vec_split; } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { shader_src = wgsl_flash_attn_tile; - defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size)); + defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); - variant += "_tile"; + variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + + std::to_string(context.max_subgroup_size); } else { defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); @@ -2677,6 +2729,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.dst_type = context.dst->type; key.wg_size = context.max_wg_size; auto it = flash_attn_vec_reduce_pipelines.find(key); if (it != flash_attn_vec_reduce_pipelines.end()) { @@ -2686,6 +2739,18 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "flash_attn_vec_reduce"; + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention vec reduce shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); variant += std::string("_hsv") + std::to_string(key.head_dim_v); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 12f60a9900e..02414bfc8b6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -187,6 +187,7 @@ struct webgpu_capabilities { uint32_t sg_mat_k = 0; uint32_t subgroup_size = 0; + uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; size_t memset_bytes_per_thread; }; @@ -1442,6 +1443,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline @@ -1750,6 +1752,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); @@ -3469,6 +3472,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( @@ -3667,8 +3671,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #endif ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. - // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. + // Runtime subgroup size can be any supported size in this range. Shaders + // that allocate per-lane register arrays must size them for the minimum. + ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize; ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16 }; @@ -4024,11 +4029,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const shader_lib_ctx.dst = const_cast(op); shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( @@ -4040,9 +4048,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } @@ -4050,9 +4058,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } @@ -4063,9 +4071,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = false; break; } - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 37ea23b80c8..ae8036b9ac5 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -1,12 +1,33 @@ enable f16; enable subgroups; +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 #define KV_STAGE_STRIDE 64 #define Q_TILE 4 #define KV_TILE 64 #define WG_SIZE 128 +#ifndef MIN_SUBGROUP_SIZE +#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE +#endif struct Params { offset_q: u32, @@ -41,13 +62,13 @@ struct Params { m1: f32, }; -@group(0) @binding(0) var Q: array; +@group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP -@group(0) @binding(1) var K: array>; +@group(0) @binding(1) var K: array>; #define V K #else -@group(0) @binding(1) var K: array>; -@group(0) @binding(2) var V: array>; +@group(0) @binding(1) var K: array>; +@group(0) @binding(2) var V: array>; #endif #if defined(MASK) && defined(SINKS) @@ -92,17 +113,17 @@ struct Params { #endif #endif -@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; const FLOAT_MIN: f32 = -1.0e9; const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; const V_CHUNKS: u32 = HEAD_DIM_V / 4u; -const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; -const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; +const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; -var q_shmem: array; -var kv_shmem: array; +var q_shmem: array; +var kv_shmem: array; var p_shmem: array; @compute @workgroup_size(WG_SIZE) @@ -158,10 +179,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_col = elem_idx % HEAD_DIM_QK; let head_q_row = q_row_start + q_tile_row; let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; - q_shmem[elem_idx] = f16(select( + q_shmem[elem_idx] = select( 0.0, - Q[global_q_row_offset + q_col] * params.scale, - head_q_row < params.seq_len_q)); + f32(Q[global_q_row_offset + q_col]) * params.scale, + head_q_row < params.seq_len_q); } workgroupBarrier(); @@ -192,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = k4.x; - kv_shmem[kv_off + 1u] = k4.y; - kv_shmem[kv_off + 2u] = k4.z; - kv_shmem[kv_off + 3u] = k4.w; + kv_shmem[kv_off + 0u] = f32(k4.x); + kv_shmem[kv_off + 1u] = f32(k4.y); + kv_shmem[kv_off + 2u] = f32(k4.z); + kv_shmem[kv_off + 3u] = f32(k4.w); } workgroupBarrier(); @@ -213,16 +234,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { let q_off = q_base + chunk * 4u; let qv = vec4( - f32(q_shmem[q_off + 0u]), - f32(q_shmem[q_off + 1u]), - f32(q_shmem[q_off + 2u]), - f32(q_shmem[q_off + 3u])); + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; let kv = vec4( - f32(kv_shmem[kv_off + 0u]), - f32(kv_shmem[kv_off + 1u]), - f32(kv_shmem[kv_off + 2u]), - f32(kv_shmem[kv_off + 3u])); + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); dot_val += dot(qv, kv); } #ifdef LOGIT_SOFTCAP @@ -264,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = v4.x; - kv_shmem[kv_off + 1u] = v4.y; - kv_shmem[kv_off + 2u] = v4.z; - kv_shmem[kv_off + 3u] = v4.w; + kv_shmem[kv_off + 0u] = f32(v4.x); + kv_shmem[kv_off + 1u] = f32(v4.y); + kv_shmem[kv_off + 2u] = f32(v4.z); + kv_shmem[kv_off + 3u] = f32(v4.w); } workgroupBarrier(); @@ -288,10 +309,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let p = p_shmem[subgroup_p_offset + kv_local]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; let v4 = vec4( - f32(kv_shmem[kv_off + 0u]), - f32(kv_shmem[kv_off + 1u]), - f32(kv_shmem[kv_off + 2u]), - f32(kv_shmem[kv_off + 3u])); + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); acc += p * v4; } out_regs[reg_idx] = acc; @@ -324,7 +345,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } let dst_vec_index = (row_base + chunk * 4u) >> 2u; - dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum; + dst[dst_vec_index] = vec4(out_regs[reg_idx] * inv_exp_sum); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index 9a0de82a56a..1091d744073 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -2,6 +2,12 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + // Default values #define HEAD_DIM_V 64 #define WG_SIZE 128 @@ -17,7 +23,7 @@ struct Params { }; @group(0) @binding(0) var tmp: array; -@group(0) @binding(1) var dst: array>; +@group(0) @binding(1) var dst: array>; @group(0) @binding(2) var params: Params; const FLOAT_MIN: f32 = -1.0e9; @@ -72,7 +78,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (thread == 0u) { let dst_vec_index = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + dst[dst_vec_index] = vec4(vec4(sum_x, sum_y, sum_z, sum_w) * inv_s); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index b1e234784a8..30ebbebe772 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -8,6 +8,18 @@ enable subgroups; #define KV_TYPE f16 #endif +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 @@ -89,7 +101,7 @@ struct Params { nwg: u32, }; -@group(0) @binding(0) var Q: array; +@group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP #if defined(KV_Q4_0) || defined(KV_Q8_0) @group(0) @binding(1) var K: array; @@ -191,41 +203,41 @@ struct Params { @group(0) @binding(BLK_BINDING) var blk: array; #endif @group(0) @binding(TMP_BINDING) var tmp: array; -@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; // Just a very small float value. const FLOAT_MIN: f32 = -1.0e9; -var q_shmem: array; +var q_shmem: array; #ifndef KV_DIRECT const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); // we can reuse the same shmem for K and V since we only need one at a time -var kv_shmem: array; +var kv_shmem: array; #endif -var o_shmem: array; +var o_shmem: array; #ifdef MASK // storage for mask values -var mask_shmem: array; +var mask_shmem: array; #endif // note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; +var inter_shmem: array; // Storage for row max and exp sum during online softmax fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx]) * params.scale, + inter_shmem[kv_idx] * params.scale, kv_idx < KV_TILE); #ifdef LOGIT_SOFTCAP v = params.logit_softcap * tanh(v); #endif #ifdef MASK if (apply_mask) { - var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE); + var mask_val = select(0.0, mask_shmem[kv_idx], kv_idx < KV_TILE); v += select(mask_val, slope * mask_val, has_bias); } #endif @@ -289,10 +301,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load the single Q row into shared memory for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) { let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1; - q_shmem[elem_idx] = f16(select( + q_shmem[elem_idx] = select( 0.0, - Q[global_q_row_offset + elem_idx], - q_row_start < params.seq_len_q)); + f32(Q[global_q_row_offset + elem_idx]), + q_row_start < params.seq_len_q); } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { @@ -308,7 +320,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let blk_state = blk_state_local; let skip_tile = blk_state == 0u; for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = f16(0.0); + inter_shmem[elem_idx] = 0.0; } // load k tile into shared memory @@ -331,8 +343,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); + let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_lo; kv_shmem[row_offset + idx + 16u] = q_hi; @@ -359,7 +371,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; + let q_val = f32(q_byte) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_val; } @@ -377,10 +389,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; let vec_idx = (global_k_row_offset + k_col) >> 2u; let k4 = select(vec4(0.0), K[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f16(k4.x); - kv_shmem[elem_idx + 1u] = f16(k4.y); - kv_shmem[elem_idx + 2u] = f16(k4.z); - kv_shmem[elem_idx + 3u] = f16(k4.w); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); } #endif @@ -401,20 +413,20 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_off = i * 4u; let qv = vec4( - f32(q_shmem[q_off + 0u]), - f32(q_shmem[q_off + 1u]), - f32(q_shmem[q_off + 2u]), - f32(q_shmem[q_off + 3u])); + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); #ifdef KV_DIRECT let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); let kv = vec4(K[idx >> 2u]); #else let idx = kv_idx * HEAD_DIM_QK + (i * 4u); let kv = vec4( - f32(kv_shmem[idx + 0u]), - f32(kv_shmem[idx + 1u]), - f32(kv_shmem[idx + 2u]), - f32(kv_shmem[idx + 3u])); + kv_shmem[idx + 0u], + kv_shmem[idx + 1u], + kv_shmem[idx + 2u], + kv_shmem[idx + 3u]); #endif partial_sum += dot(qv, kv); } @@ -435,7 +447,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); if (tx == 0u && kv_valid) { - inter_shmem[kv_idx] = f16(sum_bcast); + inter_shmem[kv_idx] = sum_bcast; } } } @@ -450,7 +462,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_k_col = kv_tile + elem_idx; let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; let mask_idx = mask_global_offset + global_k_col; - mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + mask_shmem[elem_idx] = select(0.0f, f32(mask[mask_idx]), mask_in_bounds); } } #else @@ -483,7 +495,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); total_exp_term += subgroupAdd(cur_p); if (kv_idx < KV_TILE) { - inter_shmem[kv_idx] = f16(cur_p); + inter_shmem[kv_idx] = cur_p; } } @@ -493,7 +505,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum = exp_sum * cur_exp + total_exp_term; for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp); + o_shmem[elem_idx] = o_shmem[elem_idx] * cur_exp; } } @@ -517,8 +529,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); + let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_lo; kv_shmem[row_offset + idx + 16u] = q_hi; @@ -545,7 +557,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; + let q_val = f32(q_byte) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_val; } @@ -563,10 +575,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; let vec_idx = (global_v_row_offset + v_col) >> 2u; let v4 = select(vec4(0.0), V[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f16(v4.x); - kv_shmem[elem_idx + 1u] = f16(v4.y); - kv_shmem[elem_idx + 2u] = f16(v4.z); - kv_shmem[elem_idx + 3u] = f16(v4.w); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); } #endif @@ -589,17 +601,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } - let p = f32(inter_shmem[kv_idx]); + let p = inter_shmem[kv_idx]; #ifdef KV_DIRECT let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; let v4 = vec4(V[v_idx >> 2u]); #else let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; let v4 = vec4( - f32(kv_shmem[v_idx + 0u]), - f32(kv_shmem[v_idx + 1u]), - f32(kv_shmem[v_idx + 2u]), - f32(kv_shmem[v_idx + 3u])); + kv_shmem[v_idx + 0u], + kv_shmem[v_idx + 1u], + kv_shmem[v_idx + 2u], + kv_shmem[v_idx + 3u]); #endif lo += p * v4; } @@ -630,10 +642,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (ty_pv == 0u) { let elem_base = vec_col * 4u; - o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x); - o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y); - o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z); - o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w); + o_shmem[elem_base + 0u] = o_shmem[elem_base + 0u] + lo_x; + o_shmem[elem_base + 1u] = o_shmem[elem_base + 1u] + lo_y; + o_shmem[elem_base + 2u] = o_shmem[elem_base + 2u] + lo_z; + o_shmem[elem_base + 3u] = o_shmem[elem_base + 3u] + lo_w; } } } @@ -660,7 +672,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum = exp_sum * max_exp + sink_exp_sum; for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp); + o_shmem[elem_idx] = o_shmem[elem_idx] * max_exp; } } workgroupBarrier(); @@ -681,7 +693,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); let dst_vec_index: u32 = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = v; + dst[dst_vec_index] = vec4(v); } } else { let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index b8f1bca1284..8e34e1c9ca0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -50,10 +50,25 @@ struct Params { @group(0) @binding(PARAMS_BINDING) var params: Params; +fn erf_approx(x: TYPE) -> TYPE { + let x_f32 = f32(x); + let s = select(-1.0, 1.0, x_f32 >= 0.0); + let ax = abs(x_f32); + + let t = 1.0 / (1.0 + 0.3275911 * ax); + + let y = 1.0 - + (((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t + - 0.284496736) * t + 0.254829592) * t) * + exp(-ax * ax); + + return TYPE(s * y); +} + @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { - return; + return; } var i = gid.x; let ne2 = params.ne2; @@ -71,15 +86,13 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i1 = i / ne0; let i0 = i % ne0; - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; #ifdef ABS let res = abs(src[params.offset_src + src_idx]); #endif #ifdef SGN - let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), - src[params.offset_src + src_idx] > 0.0); + let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), src[params.offset_src + src_idx] > 0.0); #endif #ifdef NEG let res = -src[params.offset_src + src_idx]; @@ -94,8 +107,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); #endif #ifdef ELU - let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], - src[params.offset_src + src_idx] > 0.0); + let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); #endif #ifdef HARDSIGMOID let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); @@ -120,31 +132,16 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = TYPE(params.fill_val); #endif #ifdef HARDSWISH - let res = src[params.offset_src + src_idx] * - min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); + let res = src[params.offset_src + src_idx] * min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); #endif #ifdef GELU - let res = 0.5 * src[params.offset_src + src_idx] * - (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * - (src[params.offset_src + src_idx] + - 0.044715 * pow(src[params.offset_src + src_idx], 3.0)), - -9.010913, 9.010913))); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(clamp(0.7978845608028654 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), -9.010913, 9.010913))); #endif #ifdef GELU_QUICK - let res = src[params.offset_src + src_idx] * 0.5 * - (1.0 + tanh(clamp(0.79788456 * - (src[params.offset_src + src_idx] + - 0.044715 * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), - -9.010913, 9.010913))); + let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -80.0, 80.0)))); #endif #ifdef GELU_ERF - let res = 0.5 * src[params.offset_src + src_idx] * - (1.0 + tanh(clamp(0.79788456 * - (src[params.offset_src + src_idx] + - 0.044715 * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), - -9.010913, 9.010913))); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.7071067811865476)); #endif #ifdef XIELU let val = f32(src[params.offset_src + src_idx]); From e8a7cd314fccab6dc2db3206a70f6f1a782031c2 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 12 May 2026 23:27:40 +0900 Subject: [PATCH 068/289] ggml-webgpu: Enables running gpt-oss-20b (llama/22906) * Enable to run gpt-oss-20b and refactor mulmat-q * disable test-backend-ops in ubuntu-24-webgpu --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 68 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 61 ++++- ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl | 64 +++++ .../wgsl-shaders/common_decls.tmpl | 7 + .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 21 ++ .../wgsl-shaders/mul_mat_decls.tmpl | 221 +++++++++++------- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 42 ++++ 7 files changed, 392 insertions(+), 92 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 932a01d385e..11701e79433 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -495,6 +495,22 @@ struct ggml_webgpu_binary_pipeline_key_hash { } }; +/* Add_Id */ + +struct ggml_webgpu_add_id_pipeline_key { + bool inplace; + + bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_add_id_pipeline_key_hash { + size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Unary **/ struct ggml_webgpu_unary_pipeline_key { @@ -1058,7 +1074,9 @@ class ggml_webgpu_shader_lib { std::unordered_map pad_pipelines; // circular/non-circular std::unordered_map - binary_pipelines; // type/op/inplace/overlap + binary_pipelines; // type/op/inplace/overlap/src_overlap + std::unordered_map + add_id_pipelines; // inplace std::unordered_map concat_pipelines; // type std::unordered_map @@ -1433,6 +1451,7 @@ class ggml_webgpu_shader_lib { case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: { // Quantized types using u32 buffers for portability. defines.push_back("SRC_TYPE=u32"); @@ -1451,6 +1470,7 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_SCALE_MIN"); defines.push_back(type_upper + "_TABLES"); defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_LUT"); variant += "_"; variant += type_str; @@ -1460,7 +1480,7 @@ class ggml_webgpu_shader_lib { if (key.src_type == GGML_TYPE_Q1_0) { defines.push_back("BLOCK_SIZE=128u"); } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || - key.src_type == GGML_TYPE_IQ4_NL) { + key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { defines.push_back("BLOCK_SIZE=256u"); @@ -1774,6 +1794,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -1908,6 +1931,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -2042,6 +2068,7 @@ class ggml_webgpu_shader_lib { case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: { // Quantized types using u32 buffers for portability. defines.push_back("SRC0_TYPE=u32"); @@ -2169,6 +2196,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -2286,6 +2316,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -2503,6 +2536,37 @@ class ggml_webgpu_shader_lib { return binary_pipelines[key]; } + webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_add_id_pipeline_key key = {}; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = add_id_pipelines.find(key); + if (it != add_id_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "add_id"; + const char * shader_src = wgsl_add_id; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(shader_src, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + add_id_pipelines[key] = pipeline; + return pipeline; + } + webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_concat_pipeline_key key = {}; key.type = context.dst->type; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 02414bfc8b6..b24101c78b0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1411,8 +1411,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q3_K: case GGML_TYPE_Q2_K: case GGML_TYPE_Q1_0: - use_fast = true; - break; case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: @@ -1422,6 +1420,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: use_fast = true; break; default: @@ -2145,6 +2144,56 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_add_id_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src2->nb[0] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + }; + + std::vector entries; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, dst)); + } + + uint32_t wg_x = 1; + uint32_t wg_y = 1; + uint32_t total_wg = ggml_nrows(dst); + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -2918,6 +2967,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, case GGML_OP_MUL: case GGML_OP_DIV: return ggml_webgpu_binary_op(ctx, src0, src1, node); + case GGML_OP_ADD_ID: + return ggml_webgpu_add_id(ctx, src0, src1, src2, node); case GGML_OP_CONCAT: return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_REPEAT: @@ -3867,6 +3918,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: return true; default: return false; @@ -3905,6 +3957,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && (src1->type == op->type); break; + case GGML_OP_ADD_ID: + supports_op = src0->type == GGML_TYPE_F32; + break; case GGML_OP_CONCAT: supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); break; @@ -3962,6 +4017,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: supports_op = true; break; default: @@ -4001,6 +4057,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: supports_op = true; break; default: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl new file mode 100644 index 00000000000..2573926cb89 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl @@ -0,0 +1,64 @@ +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_ids: u32, + offset_dst: u32, + + nb01: u32, + nb02: u32, + nb11: u32, + nb20: u32, + nb21: u32, + + ne0: u32, + ne1: u32, + ne2: u32, +}; + +@group(0) @binding(0) var src0: array; // [n_embd, n_experts_used, n_token] +@group(0) @binding(1) var src1: array; // [n_embd, n_experts] +@group(0) @binding(2) var ids: array; // [n_experts_used, n_token] + +#ifdef INPLACE + +@group(0) @binding(3) +var params: Params; + +#else + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let wg_linear = wg_id.x + wg_id.y * num_wg.x; + + if (wg_linear < params.ne1 * params.ne2) { + let thread_id = local_id.x; + let i2 = wg_linear / params.ne1; + let i1 = wg_linear % params.ne1; + + let i11 = u32(ids[params.offset_ids + i1 * params.nb20 + i2 * params.nb21]); + + let src0_row = params.offset_src0 + i1 * params.nb01 + i2 * params.nb02; + let src1_row = params.offset_src1 + i11 * params.nb11; + let dst_row = params.offset_dst + i1 * params.ne0 + i2 * (params.ne0 * params.ne1); + + for (var i = thread_id;i < params.ne0; i += WG_SIZE) { +#ifdef INPLACE + src0[src0_row + i] = src0[src0_row + i] + src1[src1_row + i]; +#else + dst[dst_row + i] = src0[src0_row + i] + src1[src1_row + i]; +#endif + } + } + +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 14c045b0ba6..372ea79bf9d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -896,3 +896,10 @@ const kvalues_iq4nl = array( ); #endif + +#ifdef MXFP4_LUT +const kvalues_mxfp4 = array( + 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 +); +#endif + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index 5710cd35469..78d61a93d28 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -652,6 +652,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif +#ifdef MXFP4 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_byte_base = (src_base + offset) * 17; + let eu8 = get_byte(load_u32_at_src(block_byte_base), 0); + let d = ldexp(1.0, i32(eu8) - 128); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 1 + j * 4; + let q_packed = load_u32_at_src(q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16u] = q_hi; + } + } +} +#endif + + @group(0) @binding(0) var src: array; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 51cf08f196f..eb2a8368f43 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -100,34 +100,37 @@ const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + + let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - shmem[shmem_idx + j * 2 + k] = q_lo; - shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -141,35 +144,38 @@ const BLOCK_SIZE_BYTES = 20u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let m = load_f16_at_src0(block_byte_base + 2u); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + + let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; let q_hi = f16((q_byte >> 4) & 0xF) * d + m; - shmem[shmem_idx + j * 2 + k] = q_lo; - shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -178,52 +184,49 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #endif // INIT_SRC0_SHMEM_Q4_1 #ifdef INIT_SRC0_SHMEM_Q5_0 -// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; const BLOCK_SIZE_BYTES = 22u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let qh_packed = load_u32_at_src0(block_byte_base + 2u); - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - let j_adjusted = j + (block_offset / 2u); - - - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k; + let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10; let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10; let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d; - - shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight - shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -232,54 +235,49 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #endif // INIT_SRC0_SHMEM_Q5_0 #ifdef INIT_SRC0_SHMEM_Q5_1 -// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; const BLOCK_SIZE_BYTES = 24u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -// tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let m = load_f16_at_src0(block_byte_base + 2u); let qh_packed = load_u32_at_src0(block_byte_base + 4u); - for (var j = 0u; j < 2; j++) { - - let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - let j_adjusted = j + (block_offset / 2u); - - - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m; - - shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight - shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k; + let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10; + let q_hi = f16(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10; + let q_lo = f16((q_byte & 0xF) | qh_lo) * d + m; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -293,33 +291,34 @@ const BLOCK_SIZE_BYTES = 34u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); - for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; - shmem[shmem_idx + j * 2 + k] = q_val; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } } @@ -333,34 +332,35 @@ const BLOCK_SIZE_BYTES = 36u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let m = load_f16_at_src0(block_byte_base + 2u); - for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d + m; - shmem[shmem_idx + j * 2 + k] = q_val; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } } @@ -1163,3 +1163,48 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_IQ3_S + +#ifdef INIT_SRC0_SHMEM_MXFP4 +const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 17u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_block_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); + let e = ldexp(1.0, i32(eu8) - 128); + + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + + let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_MXFP4 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 1f59bd14863..711c7e829d8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -1389,3 +1389,45 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src return acc; } #endif + +#ifdef MUL_ACC_MXFP4 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 17 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); + let e = ldexp(1.0, i32(eu8) - 128); + var row_sum = 0.0; + let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif From 1caed1d2bae10279bc56261c2e03c3a86a27f4bb Mon Sep 17 00:00:00 2001 From: yzyyzyhhh <96101183+happyyzy@users.noreply.github.com> Date: Wed, 13 May 2026 04:10:37 +0800 Subject: [PATCH 069/289] opencl: add opt-in Adreno xmem F16xF32 GEMM for prefill (llama/22755) * ggml-opencl: add Adreno xmem F16xF32 GEMM for prefill * ggml-opencl: address Adreno xmem review comments * ggml-opencl: align xmem gemm kernel naming --------- Co-authored-by: Your Name --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 220 +++++++++++++++++ .../kernels/gemm_xmem_f16_f32_os8.cl | 233 ++++++++++++++++++ 3 files changed, 457 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 7edb3eb4e9c..0b39c011371 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -176,6 +176,10 @@ set(GGML_OPENCL_KERNELS flash_attn_f32 ) +if (GGML_OPENCL_USE_ADRENO_KERNELS) + list(APPEND GGML_OPENCL_KERNELS gemm_xmem_f16_f32_os8) +endif () + foreach (K ${GGML_OPENCL_KERNELS}) ggml_opencl_add_kernel(${K}) endforeach() diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 73a58f74a94..61bdc62cd10 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -407,6 +407,8 @@ struct ggml_backend_opencl_context { cl_bool non_uniform_workgroups; size_t image_max_buffer_size; + size_t image2d_max_width; + size_t image2d_max_height; cl_context context; cl_command_queue queue; @@ -420,6 +422,11 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_src0; ggml_cl_buffer prealloc_src1; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + ggml_cl_buffer prealloc_adreno_xmem_const; + bool adreno_xmem_gemm_enabled = false; +#endif + // prealloc buffers for MoE router table preprocess bool toggle_reorder = false; ggml_cl_buffer prealloc_post_router; @@ -538,6 +545,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_f16_f32_tiled; + cl_kernel kernel_adreno_xmem_pack_src_f32; + cl_kernel kernel_adreno_xmem_prepack_weight_f16; + cl_kernel kernel_gemm_xmem_f16_f32_os8; + cl_kernel kernel_adreno_xmem_store_dst_f32; cl_kernel kernel_mul_mm_f16_f32_kqv; cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; @@ -1554,6 +1565,32 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // gemm_xmem_f16_f32_os8 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_xmem_f16_f32_os8.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_xmem_f16_f32_os8.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_adreno_xmem_pack_src_f32 = + clCreateKernel(prog, "adreno_xmem_pack_src_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_adreno_xmem_prepack_weight_f16 = + clCreateKernel(prog, "adreno_xmem_prepack_weight_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_gemm_xmem_f16_f32_os8 = + clCreateKernel(prog, "kernel_gemm_xmem_f16_f32_os8", &err), err)); + CL_CHECK((backend_ctx->kernel_adreno_xmem_store_dst_f32 = + clCreateKernel(prog, "adreno_xmem_store_dst_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // mul_mm_f32_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3473,6 +3510,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL); GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size); + clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL); + clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL); + GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); @@ -3511,6 +3552,16 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + backend_ctx->adreno_xmem_gemm_enabled = getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr && + backend_ctx->gpu_family == GPU_FAMILY::ADRENO; + if (getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr) { + GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM %s\n", + backend_ctx->adreno_xmem_gemm_enabled ? + "enabled (temporary weight prepack)" : "requested but unsupported by this driver"); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use large buffer for Adreno backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; @@ -9920,6 +9971,169 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +static bool ggml_cl_can_use_adreno_xmem_gemm_f16_f32( + const ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + const ggml_tensor * dst) { + if (!backend_ctx->adreno_xmem_gemm_enabled) { + return false; + } + if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { + return false; + } + if (src0->type != GGML_TYPE_F16 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + if (src0->ne[2] != 1 || src0->ne[3] != 1 || + src1->ne[2] != 1 || src1->ne[3] != 1 || + dst->ne[2] != 1 || dst->ne[3] != 1) { + return false; + } + const int K = src0->ne[0]; + const int M = src0->ne[1]; + const int N = src1->ne[1]; + if (src1->ne[0] != K || dst->ne[0] != M || dst->ne[1] != N) { + return false; + } + if (N <= 1 || M < 64 || N < 16 || K < 64) { + return false; + } + if ((K % 8) != 0) { + return false; + } + const int kpack = K / 4; + const int npack = CEIL_DIV(M, 4); + if (static_cast(N) > backend_ctx->image2d_max_width || + static_cast(kpack) > backend_ctx->image2d_max_height) { + return false; + } + if (static_cast(N) > backend_ctx->image2d_max_width || + static_cast(npack) > backend_ctx->image2d_max_height) { + return false; + } + return true; +} + +static void ggml_cl_mul_mat_f16_f32_adreno_xmem( + ggml_backend_t backend, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + const cl_ulong offset0 = extra0->offset + src0->view_offs; + const cl_ulong offset1 = extra1->offset + src1->view_offs; + const cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int K = src0->ne[0]; + const int M = src0->ne[1]; + const int N = src1->ne[1]; + const int kpack = K / 4; + const int npack = CEIL_DIV(M, 4); + const int os = 8; + + const size_t xmem_bytes = 6144; + const size_t weight_bytes = static_cast(kpack) * static_cast(npack) * 4u * sizeof(cl_half4); + + backend_ctx->prealloc_adreno_xmem_const.allocate(backend_ctx->context, xmem_bytes); + + cl_int err = CL_SUCCESS; + cl_image_format fmt = {}; + fmt.image_channel_order = CL_RGBA; + fmt.image_channel_data_type = CL_HALF_FLOAT; + + cl_image_desc desc_src = {}; + desc_src.image_type = CL_MEM_OBJECT_IMAGE2D; + desc_src.image_width = static_cast(N); + desc_src.image_height = static_cast(kpack); + cl_mem src_img = clCreateImage(backend_ctx->context, CL_MEM_READ_WRITE, &fmt, &desc_src, nullptr, &err); + CL_CHECK(err); + + cl_image_desc desc_dst = {}; + desc_dst.image_type = CL_MEM_OBJECT_IMAGE2D; + desc_dst.image_width = static_cast(N); + desc_dst.image_height = static_cast(npack); + cl_mem dst_img = clCreateImage(backend_ctx->context, CL_MEM_READ_WRITE, &fmt, &desc_dst, nullptr, &err); + CL_CHECK(err); + + cl_mem weights = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, weight_bytes, nullptr, &err); + CL_CHECK(err); + + cl_kernel prepack = backend_ctx->kernel_adreno_xmem_prepack_weight_f16; + CL_CHECK(clSetKernelArg(prepack, 0, sizeof(cl_mem), &weights)); + CL_CHECK(clSetKernelArg(prepack, 1, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(prepack, 2, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(prepack, 3, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(prepack, 4, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(prepack, 5, sizeof(int), &kpack)); + CL_CHECK(clSetKernelArg(prepack, 6, sizeof(int), &npack)); + CL_CHECK(clSetKernelArg(prepack, 7, sizeof(int), &os)); + size_t lws = 256; + size_t max_wg = backend_ctx->get_kernel_workgroup_size(prepack); + if (lws > max_wg) { + lws = max_wg; + } + size_t gws = CEIL_DIV(static_cast(kpack) * static_cast(npack), lws) * lws; + backend_ctx->enqueue_ndrange_kernel(prepack, 1, &gws, &lws, dst); + + cl_kernel pack_src = backend_ctx->kernel_adreno_xmem_pack_src_f32; + CL_CHECK(clSetKernelArg(pack_src, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(pack_src, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(pack_src, 2, sizeof(cl_mem), &src_img)); + CL_CHECK(clSetKernelArg(pack_src, 3, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(pack_src, 4, sizeof(int), &N)); + size_t pack_src_lws[2] = { 16, 16 }; + size_t pack_src_gws[2] = { + CEIL_DIV(static_cast(N), pack_src_lws[0])*pack_src_lws[0], + CEIL_DIV(static_cast(kpack), pack_src_lws[1])*pack_src_lws[1] + }; + backend_ctx->enqueue_ndrange_kernel(pack_src, 2, pack_src_gws, pack_src_lws, dst); + + cl_kernel gemm = backend_ctx->kernel_gemm_xmem_f16_f32_os8; + CL_CHECK(clSetKernelArg(gemm, 0, sizeof(cl_mem), &weights)); + CL_CHECK(clSetKernelArg(gemm, 1, sizeof(cl_mem), &backend_ctx->prealloc_adreno_xmem_const.buffer)); + CL_CHECK(clSetKernelArg(gemm, 2, sizeof(cl_mem), &src_img)); + CL_CHECK(clSetKernelArg(gemm, 3, sizeof(cl_mem), &dst_img)); + CL_CHECK(clSetKernelArg(gemm, 4, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(gemm, 5, sizeof(int), &npack)); + CL_CHECK(clSetKernelArg(gemm, 6, sizeof(int), &kpack)); + const size_t z_values = CEIL_DIV(static_cast(npack), static_cast(os)); + size_t gemm_lws[3] = { 64, 1, 1 }; + size_t gemm_gws[3] = { + z_values*gemm_lws[0], + CEIL_DIV(static_cast(N), gemm_lws[0]), + 1 + }; + backend_ctx->enqueue_ndrange_kernel(gemm, 3, gemm_gws, gemm_lws, dst); + + cl_kernel store_dst = backend_ctx->kernel_adreno_xmem_store_dst_f32; + CL_CHECK(clSetKernelArg(store_dst, 0, sizeof(cl_mem), &dst_img)); + CL_CHECK(clSetKernelArg(store_dst, 1, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(store_dst, 2, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(store_dst, 3, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(store_dst, 4, sizeof(int), &N)); + size_t store_lws[2] = { 16, 16 }; + size_t store_gws[2] = { + CEIL_DIV(static_cast(N), store_lws[0])*store_lws[0], + CEIL_DIV(static_cast(npack), store_lws[1])*store_lws[1] + }; + backend_ctx->enqueue_ndrange_kernel(store_dst, 2, store_gws, store_lws, dst); + + CL_CHECK(clReleaseMemObject(weights)); + CL_CHECK(clReleaseMemObject(dst_img)); + CL_CHECK(clReleaseMemObject(src_img)); +} +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_TENSOR_BINARY_OP_LOCALS; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -11681,6 +11895,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } case GGML_TYPE_F16: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (ggml_cl_can_use_adreno_xmem_gemm_f16_f32(backend_ctx, src0, src1, dst)) { + ggml_cl_mul_mat_f16_f32_adreno_xmem(backend, src0, src1, dst); + return; + } +#endif kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm; nth0 = 128; // calculated as (BM*BN)/(TM*TN) diff --git a/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl b/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl new file mode 100644 index 00000000000..df9d9aed067 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl @@ -0,0 +1,233 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load : enable + +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void adreno_xmem_pack_src_f32( + __global const void * src_void, + ulong offset, + __write_only image2d_t src_img, + int K, + int N) { + const int x = get_global_id(0); + const int y = get_global_id(1); + const int kpack = K / 4; + + if (x >= N || y >= kpack) { + return; + } + + __global const float * src = (__global const float *)((__global const char *)src_void + offset); + const int base = x*K + y*4; + const half4 v = (half4)((half)src[base + 0], (half)src[base + 1], (half)src[base + 2], (half)src[base + 3]); + write_imageh(src_img, (int2)(x, y), v); +} + +__kernel void adreno_xmem_prepack_weight_f16( + __global half4 * dst, + __global const void * src_void, + ulong offset, + int K, + int M, + int kpack, + int npack, + int os) { + const int linear = get_global_id(0); + const int total = kpack*npack; + if (linear >= total) { + return; + } + + __global const half * src = (__global const half *)((__global const char *)src_void + offset); + + const int dst_ogroup = linear % os; + const int dst_o_sp_i = linear / os; + const int dst_i = dst_o_sp_i % kpack; + const int dst_o = dst_o_sp_i / kpack; + const int o_slice = dst_o*os + dst_ogroup; + const int k_base = dst_i*4; + + half4 w0 = (half4)(0.0h); + half4 w1 = (half4)(0.0h); + half4 w2 = (half4)(0.0h); + half4 w3 = (half4)(0.0h); + + const int o0 = o_slice*4 + 0; + const int o1 = o_slice*4 + 1; + const int o2 = o_slice*4 + 2; + const int o3 = o_slice*4 + 3; + + if (k_base + 0 < K) { + if (o0 < M) w0.s0 = src[o0*K + k_base + 0]; + if (o1 < M) w0.s1 = src[o1*K + k_base + 0]; + if (o2 < M) w0.s2 = src[o2*K + k_base + 0]; + if (o3 < M) w0.s3 = src[o3*K + k_base + 0]; + } + if (k_base + 1 < K) { + if (o0 < M) w1.s0 = src[o0*K + k_base + 1]; + if (o1 < M) w1.s1 = src[o1*K + k_base + 1]; + if (o2 < M) w1.s2 = src[o2*K + k_base + 1]; + if (o3 < M) w1.s3 = src[o3*K + k_base + 1]; + } + if (k_base + 2 < K) { + if (o0 < M) w2.s0 = src[o0*K + k_base + 2]; + if (o1 < M) w2.s1 = src[o1*K + k_base + 2]; + if (o2 < M) w2.s2 = src[o2*K + k_base + 2]; + if (o3 < M) w2.s3 = src[o3*K + k_base + 2]; + } + if (k_base + 3 < K) { + if (o0 < M) w3.s0 = src[o0*K + k_base + 3]; + if (o1 < M) w3.s1 = src[o1*K + k_base + 3]; + if (o2 < M) w3.s2 = src[o2*K + k_base + 3]; + if (o3 < M) w3.s3 = src[o3*K + k_base + 3]; + } + + dst[linear*4 + 0] = w0; + dst[linear*4 + 1] = w1; + dst[linear*4 + 2] = w2; + dst[linear*4 + 3] = w3; +} + +__attribute__((qcom_max_concurrent_subgroups(12))) +__kernel void kernel_gemm_xmem_f16_f32_os8( + __constant half8 * weights_buffer __attribute__((sub_group_uniform)), + __constant half8 * xmem_buffer __attribute__((max_constant_size((6144)))), + __read_only image2d_t src_img, + __write_only image2d_t dst_img, + int N, + int npack, + int kpack) { + const int X = get_group_id(1)*get_local_size(0) + get_local_id(0); + const int Z = get_group_id(0)*get_local_size(2) + get_local_id(2); + + if (X >= N || Z*8 >= npack) { + return; + } + + half4 r0 = (half4)(0.0h); + half4 r1 = (half4)(0.0h); + half4 r2 = (half4)(0.0h); + half4 r3 = (half4)(0.0h); + half4 r4 = (half4)(0.0h); + half4 r5 = (half4)(0.0h); + half4 r6 = (half4)(0.0h); + half4 r7 = (half4)(0.0h); + + int f_offset = Z*kpack*32; + int subgroup_id = (int)(0x1F & qcom_get_physical_sub_group_id()); + subgroup_id = subgroup_id % 12; + const int c_offset = subgroup_id*32; + __constant half16 * weights_cache = (__constant half16 *)&xmem_buffer[c_offset]; + + int coord_s = 0; + do { + const half4 src0 = read_imageh(src_img, smp_zero, (int2)(X, coord_s)); + coord_s++; + const half4 src1 = read_imageh(src_img, smp_zero, (int2)(X, coord_s)); + coord_s++; + + qcom_sub_group_constant_load8(xmem_buffer, weights_buffer, c_offset, f_offset >> 1, 32); + f_offset += 64; + qcom_sub_group_sync(QCOM_CLK_CONST_LOAD_SYNC); + + r0 += src0.x * weights_cache[0].s0123; + r0 += src0.y * weights_cache[0].s4567; + r0 += src0.z * weights_cache[0].s89ab; + r0 += src0.w * weights_cache[0].scdef; + r1 += src0.x * weights_cache[1].s0123; + r1 += src0.y * weights_cache[1].s4567; + r1 += src0.z * weights_cache[1].s89ab; + r1 += src0.w * weights_cache[1].scdef; + r2 += src0.x * weights_cache[2].s0123; + r2 += src0.y * weights_cache[2].s4567; + r2 += src0.z * weights_cache[2].s89ab; + r2 += src0.w * weights_cache[2].scdef; + r3 += src0.x * weights_cache[3].s0123; + r3 += src0.y * weights_cache[3].s4567; + r3 += src0.z * weights_cache[3].s89ab; + r3 += src0.w * weights_cache[3].scdef; + r4 += src0.x * weights_cache[4].s0123; + r4 += src0.y * weights_cache[4].s4567; + r4 += src0.z * weights_cache[4].s89ab; + r4 += src0.w * weights_cache[4].scdef; + r5 += src0.x * weights_cache[5].s0123; + r5 += src0.y * weights_cache[5].s4567; + r5 += src0.z * weights_cache[5].s89ab; + r5 += src0.w * weights_cache[5].scdef; + r6 += src0.x * weights_cache[6].s0123; + r6 += src0.y * weights_cache[6].s4567; + r6 += src0.z * weights_cache[6].s89ab; + r6 += src0.w * weights_cache[6].scdef; + r7 += src0.x * weights_cache[7].s0123; + r7 += src0.y * weights_cache[7].s4567; + r7 += src0.z * weights_cache[7].s89ab; + r7 += src0.w * weights_cache[7].scdef; + + r0 += src1.x * weights_cache[8].s0123; + r0 += src1.y * weights_cache[8].s4567; + r0 += src1.z * weights_cache[8].s89ab; + r0 += src1.w * weights_cache[8].scdef; + r1 += src1.x * weights_cache[9].s0123; + r1 += src1.y * weights_cache[9].s4567; + r1 += src1.z * weights_cache[9].s89ab; + r1 += src1.w * weights_cache[9].scdef; + r2 += src1.x * weights_cache[10].s0123; + r2 += src1.y * weights_cache[10].s4567; + r2 += src1.z * weights_cache[10].s89ab; + r2 += src1.w * weights_cache[10].scdef; + r3 += src1.x * weights_cache[11].s0123; + r3 += src1.y * weights_cache[11].s4567; + r3 += src1.z * weights_cache[11].s89ab; + r3 += src1.w * weights_cache[11].scdef; + r4 += src1.x * weights_cache[12].s0123; + r4 += src1.y * weights_cache[12].s4567; + r4 += src1.z * weights_cache[12].s89ab; + r4 += src1.w * weights_cache[12].scdef; + r5 += src1.x * weights_cache[13].s0123; + r5 += src1.y * weights_cache[13].s4567; + r5 += src1.z * weights_cache[13].s89ab; + r5 += src1.w * weights_cache[13].scdef; + r6 += src1.x * weights_cache[14].s0123; + r6 += src1.y * weights_cache[14].s4567; + r6 += src1.z * weights_cache[14].s89ab; + r6 += src1.w * weights_cache[14].scdef; + r7 += src1.x * weights_cache[15].s0123; + r7 += src1.y * weights_cache[15].s4567; + r7 += src1.z * weights_cache[15].s89ab; + r7 += src1.w * weights_cache[15].scdef; + } while (coord_s < kpack); + + int coord_s_out = Z*8; + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r0); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r1); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r2); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r3); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r4); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r5); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r6); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r7); } +} + +__kernel void adreno_xmem_store_dst_f32( + __read_only image2d_t dst_img, + __global void * dst_void, + ulong offset, + int M, + int N) { + const int x = get_global_id(0); + const int y = get_global_id(1); + const int npack = (M + 3) / 4; + + if (x >= N || y >= npack) { + return; + } + + __global float * dst = (__global float *)((__global char *)dst_void + offset); + const half4 hv = read_imageh(dst_img, smp_zero, (int2)(x, y)); + const int m = y*4; + if (m + 0 < M) dst[x*M + m + 0] = (float)hv.s0; + if (m + 1 < M) dst[x*M + m + 1] = (float)hv.s1; + if (m + 2 < M) dst[x*M + m + 2] = (float)hv.s2; + if (m + 3 < M) dst[x*M + m + 3] = (float)hv.s3; +} From bcaf4498269fea44b47621a4f0e8dff562c93cd4 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Tue, 12 May 2026 19:28:02 -0500 Subject: [PATCH 070/289] hexagon: eliminate scalar VTCM loads via HVX splat helpers (llama/22993) * hexagon: add hvx_vec_repl helpers and use those for splat-from-vtcm usecase * hmx-mm: optimize per-group scale handling * hmx-fa: optimize slope load from vtcm * hmx-fa: use aligned access where possible in hmx-utils * hexagon: add hvx_vec_repl_2x_f16 helper and consolidate repl helpers --------- Co-authored-by: Max Krasnyansky --- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 5 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 29 +++----- ggml/src/ggml-hexagon/htp/hmx-utils.h | 34 ++++----- ggml/src/ggml-hexagon/htp/hvx-repl.h | 74 +++++++++++++++++++ ggml/src/ggml-hexagon/htp/hvx-utils.h | 1 + 5 files changed, 106 insertions(+), 37 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hvx-repl.h diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 8a6d7c14edf..4a4ff0b331d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -760,8 +760,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { // ALiBi slopes — only needed when has_alibi (scheme A) HVX_Vector v_slope0, v_slope1; if (args->has_alibi) { - v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]); - v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero(); + HVX_Vector v_s = hvx_vmemu(args->slopes + r); + v_slope0 = hvx_vec_repl_f16(v_s); + v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) : Q6_V_vzero(); } const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 9e8c9966e04..e05ccfd5fc7 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -180,12 +180,10 @@ static int hmx_compute_chunks(size_t vtcm_total, // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles // of the same 32 packed bytes. -static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx( - const uint8_t *packed_32, bool upper_nibbles, - const __fp16 *scale, const HVX_Vector vlut_cvt) { +static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { HVX_Vector vq = hvx_vmemu(packed_32); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); @@ -223,9 +221,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b - HVX_VectorPred q64 = Q6_Q_vsetq_R(64); - HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1])); - HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3])); + volatile HVX_Vector vscale = hvx_vmemu(scales_4); + + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); @@ -237,10 +236,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { - HVX_Vector vq = hvx_vmemu(quants_32); - HVX_Vector v_scales = hvx_vec_splat_f16(*scale); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + HVX_Vector vq = hvx_vmemu(quants_32); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); } @@ -521,12 +520,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( const uint8_t *r0 = vtcm_src + row0 * row_stride; const uint8_t *r1 = vtcm_src + row1 * row_stride; - HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx( - (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q8_0_group_hvx( - (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) - : Q6_V_vzero(); + HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); + HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h index 68f174d6937..f448ee3372a 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -77,16 +77,18 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); - __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; - const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); - const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + assert(hex_is_aligned(p0, 128)); + assert(hex_is_aligned(p1, 128)); + assert(c_byte_step % 128 == 0); if (p1) { for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; - HVX_Vector v1 = hvx_vmemu(p1); - p1 += c_byte_step; + HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step; + HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step; Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1); tile_base += dst_step; @@ -94,8 +96,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, } else { const HVX_Vector vzero = Q6_V_vzero(); for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; + HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step; Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero); tile_base += dst_step; @@ -116,16 +117,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); - __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; - const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); - const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; if (p1) { for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; - HVX_Vector v1 = hvx_vmemu(p1); - p1 += c_byte_step; + HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step; Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1); tile_base += dst_step; @@ -133,8 +132,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, } else { const HVX_Vector vzero = Q6_V_vzero(); for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; + HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step; Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero); tile_base += dst_step; diff --git a/ggml/src/ggml-hexagon/htp/hvx-repl.h b/ggml/src/ggml-hexagon/htp/hvx-repl.h new file mode 100644 index 00000000000..fdc7e6c7d2f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-repl.h @@ -0,0 +1,74 @@ +#ifndef HVX_REPL_H +#define HVX_REPL_H + +#include +#include +#include + +#include "hvx-base.h" + +static inline HVX_Vector hvx_vec_repl(HVX_Vector v, const uint8_t * ctrl) { + return Q6_V_vdelta_VV(v, hvx_vmem(ctrl)); +} + +static inline HVX_Vector hvx_vec_repl_u32(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_f32(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_f16(HVX_Vector v) { + // vdelta control to replicate first two bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_2x_f16(HVX_Vector v) { + // vdelta control to splat a pair of f16s: first half = f16[0], second half = f16[1] + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +#endif // HVX_REPL_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index a518ad37331..e0452811ec3 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -5,6 +5,7 @@ #include "hvx-types.h" #include "hvx-copy.h" +#include "hvx-repl.h" #include "hvx-scale.h" #include "hvx-exp.h" #include "hvx-inverse.h" From 8b288f5d966e3b3fb048097e991bdb815e79caf6 Mon Sep 17 00:00:00 2001 From: Sachin Sharma Date: Wed, 13 May 2026 11:43:47 +0530 Subject: [PATCH 071/289] ggml-zendnn : adaptive fallback to CPU backend for small batch sizes (llama/22681) * ggml-zendnn : add runtime env var GGML_ZENDNN_ADAPTIVE_FALLBACK to control adaptive fallback (default: enabled) * ggml-zendnn : restore original fallback logic when adaptive fallback is disabled --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 27 +++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index 4f321a25257..f1e4f991fae 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13 + GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 2b82c7c1dbb..6a83bb6b1ec 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -47,6 +47,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params.dtypes.dst = ggml_to_zendnn_type(); params.num_threads = ctx->n_threads; + zendnnl::lowoha::matmul::matmul_batch_params_t batch_params; zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C @@ -59,7 +60,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int 0.0f, // beta C, ldc, // output C[n,m] true, // is_weights_const - {}, // batch_params + batch_params, // batch_params params // params ); @@ -520,6 +521,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm GGML_UNUSED(max_tensor_size); } +static bool ggml_zendnn_adaptive_fallback_enabled() { + static const bool enabled = std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK") == nullptr || + std::atoi(std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK")) != 0; + return enabled; +} + static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { switch (op->op) { case GGML_OP_NONE: @@ -538,12 +545,24 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const const int64_t ne10 = inputs->ne[0]; const int64_t ne0 = op->ne[0]; const int64_t ne1 = op->ne[1]; - const int64_t min_batch = 1; - if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) || - ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { + + if(!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs)) { + return false; + } + + if (ggml_zendnn_adaptive_fallback_enabled()) { + const int64_t K = inputs->ne[0]; + const int64_t N = (inputs->ne[1]*inputs->ne[2]*inputs->ne[3]); + const int64_t M = weights->ne[1]; + if(K <= 256 || N <= 128 || M <= 96) { return false; + } } + else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { + return false; + } + // MUL_MAT_ID performs best with a moderate number of experts due to its // gather + batched matmul + scatter approach. Future versions will leverage // ZenDNN's grouped_gemm for better scalability with larger expert counts: From cb7d38bf18f763efa2bc794c44e648808aa8064c Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 13 May 2026 06:59:28 -0700 Subject: [PATCH 072/289] hexagon: add unary tanh op (llama/22999) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2 ++ ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 22 +++++++++++++++++++++- 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d3c125dbc3d..3d1c9da8329 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2865,6 +2865,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; + case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; default: break; } @@ -3335,6 +3336,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_TANH: supp = ggml_hexagon_supported_unary(sess, op); break; case GGML_UNARY_OP_SILU: diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 6203e3848b9..98db864dd42 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -62,6 +62,7 @@ enum htp_op_code { HTP_OP_UNARY_EXP, HTP_OP_UNARY_NEG, HTP_OP_UNARY_SOFTPLUS, + HTP_OP_UNARY_TANH, HTP_OP_GLU_SWIGLU, HTP_OP_GLU_SWIGLU_OAI, HTP_OP_GLU_GEGLU, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index fa1e0698f4a..883a31d6163 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_UNARY_SIGMOID: case HTP_OP_UNARY_NEG: case HTP_OP_UNARY_EXP: + case HTP_OP_UNARY_TANH: case HTP_OP_L2_NORM: return op_unary(octx); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 26a0e0bd793..d4ae89ee6f0 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -373,6 +373,21 @@ static void l2_norm_f32(const float * restrict src, } } +static void tanh_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_tanh_f32_aa(dst_local, src_local, row_elems); + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; @@ -477,6 +492,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_UNARY_SOFTPLUS: softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_UNARY_TANH: + tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; case HTP_OP_L2_NORM: l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -547,10 +565,12 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_UNARY_SOFTPLUS: op_type = "softplus-f32"; break; + case HTP_OP_UNARY_TANH: + op_type = "tanh-f32"; + break; case HTP_OP_L2_NORM: op_type = "l2norm-f32"; break; - default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; From 1cbbd0b6d09147657b51f04bd83e48b8db2f6711 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 14 May 2026 02:22:44 +0900 Subject: [PATCH 073/289] flush the gpu profile timestamp before the queryset is overflowed (llama/22995) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b24101c78b0..401c75c1230 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3148,6 +3148,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } ctx->param_arena.reset(); commands.clear(); +#ifdef GGML_WEBGPU_GPU_PROFILE + // flush before the next batch can overflow the QuerySet + if (ctx->profile_timestamp_query_count + 2 * ctx->global_ctx->command_submit_batch_size >= + WEBGPU_MAX_PROFILE_QUERY_COUNT) { + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); + // reset profile timestamp state + ctx->profile_timestamp_query_count = 0; + profile_pipeline_names.clear(); + } +#endif } node_idx += num_encoded_ops; From b19beb6027f45ae5a413554de05ff6896b3a5540 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 13 May 2026 11:24:33 -0700 Subject: [PATCH 074/289] opencl: fix crash when warming up MoE on Adreno (llama/22876) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 61bdc62cd10..248124c2896 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -13132,7 +13132,7 @@ static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02)); size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast(ne20), 1}; - size_t histogram_local_size[] = {64, static_cast(ne20), 1}; + size_t histogram_local_size[] = {64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); // Scan From d4a4d87f0ef6511c1e5fec36a3f84d6710f83c33 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Wed, 13 May 2026 11:57:31 -0700 Subject: [PATCH 075/289] opencl: add q5_0 and q5_1 MoE for Adreno (llama/22985) * opencl: add q5_0 moe support * opencl: add q5_1 moe support * opencl: avoid potential leak * opencl: suppress unused var warning when building for non-Adreno --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 1019 +++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 204 ++++ .../kernels/gemm_moe_q5_0_f32_ns.cl | 256 +++++ .../kernels/gemm_moe_q5_1_f32_ns.cl | 258 +++++ .../kernels/gemv_moe_q5_0_f32_ns.cl | 119 ++ .../kernels/gemv_moe_q5_1_f32_ns.cl | 121 ++ 7 files changed, 1914 insertions(+), 67 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 0b39c011371..c6aba608736 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -106,6 +106,10 @@ set(GGML_OPENCL_KERNELS gemv_moe_q4_0_f32_ns gemm_moe_q4_1_f32_ns gemv_moe_q4_1_f32_ns + gemm_moe_q5_0_f32_ns + gemv_moe_q5_0_f32_ns + gemm_moe_q5_1_f32_ns + gemv_moe_q5_1_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 248124c2896..0e511592d53 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -556,6 +556,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; + cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; + cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -615,6 +617,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; + cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns; + cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -973,6 +977,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -2995,6 +3003,74 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_q5_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_0_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q5_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_0_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q5_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_1_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q5_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_1_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3852,6 +3928,122 @@ struct ggml_tensor_extra_cl_q4_1 { } }; +struct ggml_tensor_extra_cl_q5_0 { + // Quantized values. + cl_mem qs = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem qs_img = nullptr; + // 5-th bit values. + cl_mem qh = nullptr; + // 5-th bit values in image1d_buffer_t. + cl_mem qh_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Size of quantized values. + size_t size_qs = 0; + // Size of 5-th bit values. + size_t size_qh = 0; + // Size of scales. + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q5_0() { + reset(); + } + + void reset() { + if (qs != nullptr) { + CL_CHECK(clReleaseMemObject(qs)); + qs = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (qs_img != nullptr) { + CL_CHECK(clReleaseMemObject(qs_img)); + qs_img = nullptr; + } + + qh_img = nullptr; + d_img = nullptr; + size_qs = 0; + size_qh = 0; + size_d = 0; + } +}; + +struct ggml_tensor_extra_cl_q5_1 { + // Quantized values. + cl_mem qs = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem qs_img = nullptr; + // 5-th bit values. + cl_mem qh = nullptr; + // 5-th bit values in image1d_buffer_t. + cl_mem qh_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_qs = 0; + // Size of 5-th bit values. + size_t size_qh = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q5_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (qs != nullptr) { + CL_CHECK(clReleaseMemObject(qs)); + qs = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + if (qs_img != nullptr) { + CL_CHECK(clReleaseMemObject(qs_img)); + qs_img = nullptr; + } + // qh_img, d_img, and m_img are not currently allocated separately. + // TODO: initialize them for non SMALL_PATH path, or remove them. + qh_img = nullptr; + d_img = nullptr; + m_img = nullptr; + size_qs = 0; + size_qh = 0; + size_d = 0; + size_m = 0; + } +}; + struct ggml_tensor_extra_cl_mxfp4 { // Quantized values. cl_mem q = nullptr; @@ -4506,7 +4698,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } // q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support, // the quantizations here currently do not - they are only supported by Adreno with certain shapes - if (op->src[0]->type == GGML_TYPE_Q4_1) { + if (op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || + op->src[0]->type == GGML_TYPE_Q5_1) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (op->src[1]->type == GGML_TYPE_F32) { return use_adreno_moe_kernels(backend_ctx, op->src[0]) @@ -4692,6 +4886,18 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { delete e; } + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0) { + delete e; + } + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1) { + delete e; + } + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1_in_use) { + delete e; + } for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { delete e; } @@ -4775,6 +4981,36 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q5_0 * ggml_opencl_alloc_temp_tensor_extra_q5_0() { + ggml_tensor_extra_cl_q5_0 * extra; + if (temp_tensor_extras_q5_0.empty()) { + extra = new ggml_tensor_extra_cl_q5_0(); + } else { + extra = temp_tensor_extras_q5_0.back(); + temp_tensor_extras_q5_0.pop_back(); + } + + temp_tensor_extras_q5_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q5_1 * ggml_opencl_alloc_temp_tensor_extra_q5_1() { + ggml_tensor_extra_cl_q5_1 * extra; + if (temp_tensor_extras_q5_1.empty()) { + extra = new ggml_tensor_extra_cl_q5_1(); + } else { + extra = temp_tensor_extras_q5_1.back(); + temp_tensor_extras_q5_1.pop_back(); + } + + temp_tensor_extras_q5_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { ggml_tensor_extra_cl_mxfp4 * extra; if (temp_tensor_extras_mxfp4.empty()) { @@ -4881,6 +5117,16 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_1_in_use.clear(); + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0_in_use) { + temp_tensor_extras_q5_0.push_back(e); + } + temp_tensor_extras_q5_0_in_use.clear(); + + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1_in_use) { + temp_tensor_extras_q5_1.push_back(e); + } + temp_tensor_extras_q5_1_in_use.clear(); + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { temp_tensor_extras_mxfp4.push_back(e); } @@ -4923,6 +5169,10 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_q4_0_in_use; std::vector temp_tensor_extras_q4_1; std::vector temp_tensor_extras_q4_1_in_use; + std::vector temp_tensor_extras_q5_0; + std::vector temp_tensor_extras_q5_0_in_use; + std::vector temp_tensor_extras_q5_1; + std::vector temp_tensor_extras_q5_1_in_use; std::vector temp_tensor_extras_mxfp4; std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; @@ -5286,17 +5536,18 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } - if (tensor->type == GGML_TYPE_MXFP4) { + if (tensor->type == GGML_TYPE_Q5_0) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); // Allocate the new extra and create aliases from the original. ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4(); + ggml_tensor_extra_cl_q5_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_0(); - size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; - GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_qs = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(int32_t); + GGML_ASSERT(size_d + size_qs + size_qh == ggml_nbytes(tensor) && "Incorrect tensor size"); cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -5306,40 +5557,48 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); - // The original tensor memory is divided into scales and quants, i.e., - // we first store scales, then quants. cl_buffer_region region; // Create subbuffer for scales. region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); - region.size = size_e; - extra->e = clCreateSubBuffer( + region.size = size_d; + extra->d = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); auto previous_origin = region.origin; - // Create subbuffer for quants. - region.origin = align_to(previous_origin + size_e, backend_ctx->alignment); - region.size = size_q; - extra->q = clCreateSubBuffer( + // Create subbuffer for qh. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qs. + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_qs; + extra->qs = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - // Adreno moe mxfp4 kernel needs special transpose and unshuffling + // Adreno moe q5_0 kernel needs special transpose and unshuffling if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; size_t local_work_size[3] = {64, 2, 1}; @@ -5348,61 +5607,36 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); CL_CHECK(clWaitForEvents(1, &evt)); CL_CHECK(clReleaseMemObject(data_device)); - tensor->extra = extra; // Create image for Q - cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; - cl_image_desc img_desc_q = { + cl_image_format img_format_qs = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_qs = { CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ggml_nelements(tensor) / 8), 0, 0, 0, 0, 0, 0, 0, - { extra->q } + { extra->qs } }; - extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + extra->qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_qs, &img_desc_qs, NULL, &err); tensor->extra = extra; return; } - #endif // GGML_OPENCL_USE_ADRENO_KERNELS - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); - - size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[3] = {64, 1, 1}; - - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clReleaseMemObject(data_device)); - - // Create image for Q - cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32}; - cl_image_desc img_desc_q = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(ggml_nelements(tensor)/32*2), - 0, 0, 0, 0, 0, 0, 0, - { extra->q } - }; - extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); - tensor->extra = extra; - return; } - if (tensor->type == GGML_TYPE_Q8_0) { + if (tensor->type == GGML_TYPE_Q5_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); // Allocate the new extra and create aliases from the original. ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0(); + ggml_tensor_extra_cl_q5_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_1(); size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char)); - GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_qs = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(int32_t); + GGML_ASSERT(size_d + size_m + size_qs + size_qh == ggml_nbytes(tensor) && "Incorrect tensor size"); cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -5412,10 +5646,10 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); - // The original tensor memory is divided into scales and quants, i.e., - // we first store scales, then quants. cl_buffer_region region; + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. // Create subbuffer for scales. region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_d; @@ -5425,22 +5659,227 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(err); auto previous_origin = region.origin; - // Create subbuffer for quants. + // Create subbuffer for mins. region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); - region.size = size_q; - extra->q = clCreateSubBuffer( + region.size = size_m; + extra->m = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + previous_origin = region.origin; - cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0; + // Create subbuffer for qh. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + // Create subbuffer for qs. + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_qs; + extra->qs = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {64, 1, 1}; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q5_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_qs = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_qs = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->qs } + }; + extra->qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_qs, &img_desc_qs, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_MXFP4) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4(); + + size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_e; + extra->e = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_e, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe mxfp4 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + tensor->extra = extra; + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + + size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[3] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor)/32*2), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } + if (tensor->type == GGML_TYPE_Q8_0) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char)); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; cl_event evt; CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); @@ -6109,6 +6548,89 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_Q5_0) { + ggml_tensor_extra_cl_q5_0 * extra = (ggml_tensor_extra_cl_q5_0 *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + // TODO: use ggml_cl_buffer to manage this temporary buffer + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // TODO: normal q5_0 + (void) extra; + return; + } + if (tensor->type == GGML_TYPE_Q5_1) { + ggml_tensor_extra_cl_q5_1 * extra = (ggml_tensor_extra_cl_q5_1 *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + // TODO: use ggml_cl_buffer to manage this temporary buffer + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // TODO: normal q5_1 + (void) extra; + return; + } if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; @@ -13209,10 +13731,17 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif + // TODO: general MoE for the following types + (void)extra0_q4_1; + (void)extra0_q5_0; + (void)extra0_q5_1; + const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; @@ -13540,8 +14069,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, } else { // for gemm kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns; - if (strstr(src0->name, "as") != NULL) { + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; } cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; @@ -13649,6 +14181,359 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, } return; } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_0: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_0_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_0_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qs_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_1_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_1_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qs_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } #endif //GGML_OPENCL_USE_ADRENO_KERNELS } case GGML_TYPE_Q8_0: { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 5bbf09710f9..8f06d570587 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -56,6 +56,25 @@ struct block_q4_1 { uchar qs[QK4_1 / 2]; // nibbles / quants }; +//------------------------------------------------------------------------------ +// block_q5_0 +//------------------------------------------------------------------------------ +struct block_q5_0 { + half d; // delta + uchar qh[4]; // 5-th bit of quants + uchar qs[QK5_0 / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q5_1 +//------------------------------------------------------------------------------ +struct block_q5_1 { + half d; // delta + half m; // min + uchar qh[4]; // 5-th bit of quants + uchar qs[QK5_1 / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q4_k //------------------------------------------------------------------------------ @@ -460,6 +479,191 @@ kernel void kernel_restore_block_q4_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +kernel void kernel_convert_block_q5_0_trans4_ns( + __global struct block_q5_0 * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_0; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q5_0 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + + dst_qh[dst_blk_offset] = ((global uint *)(&(b->qh[0])))[0]; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_0 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK5_0 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_qs[offset] = q_block.x; + dst_qs[offset + ne01] = q_block.y; + dst_qs[offset + ne01 * 2] = q_block.z; + dst_qs[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q5_0_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global struct block_q5_0 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_0; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_0 * b = dst0 + dst_blk_offset; + b->d = src_d[src_blk_offset]; + + ((__global uint *)(&(b->qh[0])))[0] = src_qh[src_blk_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_qs[src_q_offset]; + q_block.y = src_qs[src_q_offset + ne01]; + q_block.z = src_qs[src_q_offset + ne01 * 2]; + q_block.w = src_qs[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK5_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + +kernel void kernel_convert_block_q5_1_trans4_ns( + __global struct block_q5_1 * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_m, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_1; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q5_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; + + dst_qh[dst_blk_offset] = ((global uint *)(&(b->qh[0])))[0]; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK5_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_qs[offset] = q_block.x; + dst_qs[offset + ne01] = q_block.y; + dst_qs[offset + ne01 * 2] = q_block.z; + dst_qs[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q5_1_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_m, + __global struct block_q5_1 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_1; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_blk_offset]; + b->m = src_m[src_blk_offset]; + + ((__global uint *)(&(b->qh[0])))[0] = src_qh[src_blk_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_qs[src_q_offset]; + q_block.y = src_qs[src_q_offset + ne01]; + q_block.z = src_qs[src_q_offset + ne01 * 2]; + q_block.w = src_qs[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_1 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK5_1 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl new file mode 100644 index 00000000000..3524cb1bdbd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl @@ -0,0 +1,256 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q5_0(qs5x16, qh5x16, a_f16, scale) \ + a_f16.s0 = (half)((( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) - 16) * scale; \ + a_f16.s1 = (half)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) - 16) * scale; \ + a_f16.s2 = (half)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) - 16) * scale; \ + a_f16.s3 = (half)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) - 16) * scale; \ + a_f16.s4 = (half)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) - 16) * scale; \ + a_f16.s5 = (half)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) - 16) * scale; \ + a_f16.s6 = (half)((((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) - 16) * scale; \ + a_f16.s7 = (half)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) - 16) * scale; \ + a_f16.s8 = (half)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) - 16) * scale; \ + a_f16.s9 = (half)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) - 16) * scale; \ + a_f16.sa = (half)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) - 16) * scale; \ + a_f16.sb = (half)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) - 16) * scale; \ + a_f16.sc = (half)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) - 16) * scale; \ + a_f16.sd = (half)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) - 16) * scale; \ + a_f16.se = (half)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) - 16) * scale; \ + a_f16.sf = (half)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) - 16) * scale; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q5_0_f32_ns( + __read_only image1d_buffer_t src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current Q5_0 block + uint blk_offset = s_sub_offset + get_global_id(0); + half s = src0_d[blk_offset]; + + // Load 32 qh (5-th bit of each Q5) for the entire block + uchar4 qhx32 = as_uchar4(src0_qh[blk_offset]); + + // Load 16 qs (half block) in transposed layout + uint2 qsx16; + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_0(as_ushort4(qsx16), qhx32.lo, reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 qs in transposed layout + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_0(as_ushort4(qsx16), qhx32.hi, reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl new file mode 100644 index 00000000000..5fc2a523234 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl @@ -0,0 +1,258 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q5_1(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m); \ + a_f16.s2 = (half)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m); \ + a_f16.s3 = (half)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m); \ + a_f16.s4 = (half)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m); \ + a_f16.s5 = (half)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m); \ + a_f16.s6 = (half)((((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m); \ + a_f16.s8 = (half)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m); \ + a_f16.s9 = (half)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m); \ + a_f16.sa = (half)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m); \ + a_f16.sb = (half)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m); \ + a_f16.sc = (half)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m); \ + a_f16.sd = (half)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m); \ + a_f16.se = (half)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.sf = (half)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q5_1_f32_ns( + __read_only image1d_buffer_t src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q5_1 block + uint blk_offset = s_sub_offset + get_global_id(0); + half s = src0_d[blk_offset]; + half m = src0_m[blk_offset]; + + // Load 32 qh (5-th bit of each Q5) for the entire block + uchar4 qhx32 = as_uchar4(src0_qh[blk_offset]); + + // Load 16 qs (half block) in transposed layout + uint2 qsx16; + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_1(as_ushort4(qsx16), qhx32.lo, reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 qs in transposed layout + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_1(as_ushort4(qsx16), qhx32.hi, reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl new file mode 100644 index 00000000000..938054cf982 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl @@ -0,0 +1,119 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q5_0 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q5_0_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) - 16); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) - 16); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) - 16); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) - 16); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) - 16); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) - 16); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) - 16); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) - 16); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_0_f32_ns( + __global uint * src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + uint ne00, + uint ne01, + uint ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q5_0); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_qs[block_offset]; + regQ.s1 = src0_qs[block_offset + ne01]; + regQ.s2 = src0_qs[block_offset + ne01 * 2]; + regQ.s3 = src0_qs[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + uchar4 regQh = as_uchar4(src0_qh[ib00 * ne01 + i01 + expert_offset]); + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl new file mode 100644 index 00000000000..f33a4ef2757 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl @@ -0,0 +1,121 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q5_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q5_1_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_1_f32_ns( + __global uint * src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + uint ne00, + uint ne01, + uint ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q5_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_qs[block_offset]; + regQ.s1 = src0_qs[block_offset + ne01]; + regQ.s2 = src0_qs[block_offset + ne01 * 2]; + regQ.s3 = src0_qs[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + uchar4 regQh = as_uchar4(src0_qh[ib00 * ne01 + i01 + expert_offset]); + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From 97371e928560fffe4d1309202dcdc6290067269c Mon Sep 17 00:00:00 2001 From: scutler-nv Date: Wed, 13 May 2026 13:36:14 -0700 Subject: [PATCH 076/289] Fix for issue #22974. Cast intermediate results to float before adding and casting the result to the destination type. Avoids half+half operator ambiguity. (llama/22994) --- ggml/src/ggml-cuda/allreduce.cu | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 434689abd95..d56129a227e 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -184,13 +184,15 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); + recvbuf[off + k] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k])); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); - recvbuf[tail + tid] = - ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); + recvbuf[tail + tid] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + + ggml_cuda_cast(host_other[tail + tid])); } } } @@ -210,7 +212,8 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); + dst[i] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i])); } } From e4ce42e55f325c83b4e8fb4f848cd3fa086988c0 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 13 May 2026 15:12:40 -0700 Subject: [PATCH 077/289] ggml-webgpu: only use subgroup-matrix path when head dims are divisible by sg_mat_k / sg_mat_n (llama/23020) --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 11701e79433..62a523365b9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -777,7 +777,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && + const bool use_subgroup_matrix = + context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && + context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; + const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && @@ -785,7 +788,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : GGML_WEBGPU_FLASH_ATTN_PATH_NONE; if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { From 69500f5502bf508e3d0f350734e9158f59aab1d6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 11:53:30 +0300 Subject: [PATCH 078/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 15685a0718f..5a605ba344e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -628249b398293fc8d2fa81a449ae2920a02c6523 +57ea0bc119d722d74594196cc5b494a34dd87be4 From 46ca43d6399fdeada1b49fb2126ba373bd9ebc38 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 11:53:43 +0300 Subject: [PATCH 079/289] talk-llama : sync llama.cpp --- examples/talk-llama/llama-context.cpp | 27 ++++++++++++++++++++++----- examples/talk-llama/llama.h | 2 ++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 71a59395eb2..3d9714ab166 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2475,11 +2475,29 @@ class llama_io_write_device : public llama_io_write_i { } if (need_alloc) { - mbuf_cur = std::move(mbuf); + if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) { + mbuf_cur = std::move(mbuf); - mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); - LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } else { + //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + + // save the old buffer and allocate the new tensors in it + auto buf = std::move(mbuf_cur.buf); + + mbuf_cur = std::move(mbuf); + + ggml_tallocr talloc = ggml_tallocr_new(buf.get()); + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_view_init(mbuf_cur.org[i]); + ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]); + } + + mbuf_cur.buf = std::move(buf); + } } for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { @@ -2559,8 +2577,7 @@ class llama_io_read_device : public llama_io_read_i { mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); - auto & view = mbuf.org.back(); - view->buffer = rinfo.tensor->buffer; + ggml_backend_view_init(mbuf.org.back()); } for (auto & [buft, mbuf] : mbufs_new) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 2ea226726ad..308e8ba9dbd 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -858,6 +858,8 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 + // for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 From 968eebe77225d25e57a3f981da7c696310f0e881 Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Fri, 15 May 2026 14:03:17 +0200 Subject: [PATCH 080/289] server: add support for carry_initial_prompt (#3781) * Add support for carry_initial_prompt on the server * Update README --- examples/server/README.md | 3 + examples/server/server.cpp | 287 +++++++++++++++++++------------------ 2 files changed, 147 insertions(+), 143 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index ffba5f4edf5..8d4c802b8bf 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -40,6 +40,7 @@ options: -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect) -dl, --detect-language [false ] exit after automatically detecting language --prompt PROMPT [ ] initial prompt + --carry-initial-prompt [false ] always prepend initial prompt -m FNAME, --model FNAME [models/ggml-base.en.bin] model path -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference -dtw MODEL --dtw MODEL [ ] compute token-level timestamps @@ -78,6 +79,8 @@ curl 127.0.0.1:8080/inference \ -F file="@" \ -F temperature="0.0" \ -F temperature_inc="0.2" \ +-F prompt="" \ +-F carry_initial_prompt="true" \ -F response_format="json" ``` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 735255b6290..afc95176ec8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -56,11 +56,11 @@ inline void signal_handler(int signal) { struct server_params { - std::string hostname = "127.0.0.1"; - std::string public_path = "examples/server/public"; - std::string request_path = ""; + std::string hostname = "127.0.0.1"; + std::string public_path = "examples/server/public"; + std::string request_path = ""; std::string inference_path = "/inference"; - std::string tmp_dir = "."; + std::string tmp_dir = "."; int32_t port = 8080; int32_t read_timeout = 600; @@ -89,49 +89,45 @@ struct whisper_params { float temperature_inc = 0.20f; float no_speech_thold = 0.6f; - bool debug_mode = false; - bool translate = false; - bool detect_language = false; - bool diarize = false; - bool tinydiarize = false; - bool split_on_word = false; - bool no_fallback = false; - bool print_special = false; - bool print_colors = false; - bool print_realtime = false; - bool print_progress = false; - bool no_timestamps = false; - bool token_timestamps = true; - bool use_gpu = true; - bool flash_attn = true; - int32_t gpu_device = 0; - bool suppress_nst = false; - bool no_context = true; + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool print_special = false; + bool print_colors = false; + bool print_realtime = false; + bool print_progress = false; + bool no_timestamps = false; + bool token_timestamps = true; + bool use_gpu = true; + bool flash_attn = true; + int32_t gpu_device = 0; + bool suppress_nst = false; + bool no_context = true; bool no_language_probabilities = false; - - std::string language = "en"; - std::string prompt = ""; - std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string model = "models/ggml-base.en.bin"; - - std::string response_format = json_format; - - // [TDRZ] speaker turn string - std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line - + bool carry_initial_prompt = false; + + std::string language = "en"; + std::string prompt = ""; + std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + std::string model = "models/ggml-base.en.bin"; + std::string response_format = json_format; + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line std::string openvino_encode_device = "CPU"; - - std::string dtw = ""; + std::string dtw = ""; // Voice Activity Detection (VAD) parameters - bool vad = false; - std::string vad_model = ""; - float vad_threshold = 0.5f; - int vad_min_speech_duration_ms = 250; + bool vad = false; + std::string vad_model = ""; + float vad_threshold = 0.5f; + int vad_min_speech_duration_ms = 250; int vad_min_silence_duration_ms = 100; - float vad_max_speech_duration_s = FLT_MAX; - int vad_speech_pad_ms = 30; - float vad_samples_overlap = 0.1f; + float vad_max_speech_duration_s = FLT_MAX; + int vad_speech_pad_ms = 30; + float vad_samples_overlap = 0.1f; }; void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) { @@ -139,51 +135,52 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options] \n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); - fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); - fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); - fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); - fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false"); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); // server params - fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); - fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str()); - fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); - fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); - fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); - fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); - fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); - fprintf(stderr, " --tmp-dir, [%-7s] Temporary directory for ffmpeg transcoded files\n", sparams.tmp_dir.c_str()); - fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); - fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); - fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); - fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); - fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str()); + fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); + fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); + fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); + fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); + fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); + fprintf(stderr, " --tmp-dir, [%-7s] Temporary directory for ffmpeg transcoded files\n", sparams.tmp_dir.c_str()); + fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); + fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); // Voice Activity Detection (VAD) parameters fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n"); @@ -191,10 +188,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str()); fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold); fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms); - fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms); - fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ? - std::string("FLT_MAX").c_str() : - std::to_string(params.vad_max_speech_duration_s).c_str()); + fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms); + fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ? std::string("FLT_MAX").c_str() : std::to_string(params.vad_max_speech_duration_s).c_str()); fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms); fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap); fprintf(stderr, "\n"); @@ -212,63 +207,64 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve whisper_print_usage(argc, argv, params, sparams); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } - else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } - else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } - else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } - else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-debug" || arg == "--debug-mode") { params.debug_mode = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } + else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } // server params - else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } - else if ( arg == "--host") { sparams.hostname = argv[++i]; } - else if ( arg == "--public") { sparams.public_path = argv[++i]; } - else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } - else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } - else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } - else if ( arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } + else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } + else if ( arg == "--host") { sparams.hostname = argv[++i]; } + else if ( arg == "--public") { sparams.public_path = argv[++i]; } + else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } + else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } + else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } + else if ( arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } // Voice Activity Detection (VAD) - else if ( arg == "--vad") { params.vad = true; } - else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } - else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } - else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } - else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } - else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } + else if ( arg == "--vad") { params.vad = true; } + else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } + else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } + else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } + else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } + else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } + else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } + else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params, sparams); @@ -573,6 +569,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.prompt = req.get_file_value("prompt").content; } + if (req.has_file("carry_initial_prompt")) + { + params.carry_initial_prompt = parse_str_to_bool(req.get_file_value("carry_initial_prompt").content); + } if (req.has_file("response_format")) { params.response_format = req.get_file_value("response_format").content; @@ -940,6 +940,7 @@ int main(int argc, char ** argv) { wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.initial_prompt = params.prompt.c_str(); + wparams.carry_initial_prompt = params.carry_initial_prompt; wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; From 6227a0ef739a78312d96e6f8f85e7b6d63683445 Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Mon, 18 May 2026 09:18:04 +0200 Subject: [PATCH 081/289] server : Return speaker information in JSON (#3782) --- examples/server/server.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index afc95176ec8..590378b725f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -315,10 +315,10 @@ std::string generate_temp_filename(const std::string &path, const std::string &p return ss.str(); } -bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) { +bool convert_to_wav(const std::string & temp_filename, std::string & error_resp, bool stereo) { std::ostringstream cmd_stream; std::string converted_filename_temp = temp_filename + "_temp.wav"; - cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; + cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac " << (stereo ? 2 : 1) << " -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; std::string cmd = cmd_stream.str(); int status = std::system(cmd.c_str()); @@ -341,7 +341,7 @@ bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) return true; } -std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { +std::string estimate_diarization_speaker(const std::vector> & pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { std::string speaker = ""; const int64_t n_samples = pcmf32s[0].size(); @@ -451,7 +451,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } -std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector> pcmf32s) { +std::string output_str(struct whisper_context * ctx, const whisper_params & params, const std::vector> & pcmf32s) { std::stringstream result; const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { @@ -848,7 +848,7 @@ int main(int argc, char ** argv) { temp_file.close(); std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}"; - const bool is_converted = convert_to_wav(temp_filename, error_resp); + const bool is_converted = convert_to_wav(temp_filename, error_resp, params.diarize); if (!is_converted) { res.status = 500; res.set_content(error_resp, "application/json"); @@ -1091,6 +1091,14 @@ int main(int argc, char ** argv) { segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; } + if (params.diarize && pcmf32s.size() == 2) { + segment["speaker"] = estimate_diarization_speaker( + pcmf32s, + whisper_full_get_segment_t0(ctx, i), + whisper_full_get_segment_t1(ctx, i), + true); + } + float total_logprob = 0; const int n_tokens = whisper_full_n_tokens(ctx, i); for (int j = 0; j < n_tokens; ++j) { From 47b9eb37a33c5031a1b667ace64477330b9f36c1 Mon Sep 17 00:00:00 2001 From: petterreinholdtsen Date: Mon, 18 May 2026 12:16:39 +0200 Subject: [PATCH 082/289] examples : fix memory leak in read_audio_data (#3810) This commit addresses a memory leak in the `read_audio_data` function where it is currently possible that a call to `ma_decoder_init_file` succeeds and the function returns early without calling `ma_decoder_uninit`. A similar situation can occur with `ma_decoder_init_memory`. Refs: https://bugs.debian.org/1124796 Co-authored-by: Daniel Bevenius --- examples/common-whisper.cpp | 55 +++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index 6218a882eb5..977527a0ca5 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -44,7 +44,18 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: ma_result result; ma_decoder_config decoder_config; - ma_decoder decoder; + + struct decoder_guard { + ma_decoder decoder; + bool initialized = false; + ma_decoder * operator&() { return &decoder; } + ~decoder_guard() { + if (initialized) { + ma_decoder_uninit(&decoder); + } + } + }; + decoder_guard decoder{}; decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); @@ -63,32 +74,36 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: audio_data.insert(audio_data.end(), buf, buf + n); } - if ((result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); + if (result != MA_SUCCESS) { fprintf(stderr, "Error: failed to open audio data from stdin (%s)\n", ma_result_description(result)); - return false; } + decoder.initialized = true; fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); } - else if (((result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder)) != MA_SUCCESS)) { + else { + result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); + if (result == MA_SUCCESS) { + decoder.initialized = true; + } #if defined(WHISPER_FFMPEG) - if (ffmpeg_decode_audio(fname, audio_data) != 0) { - fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); - - return false; - } - - if ((result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); - - return false; - } + if (!decoder.initialized) { + if (ffmpeg_decode_audio(fname, audio_data) != 0) { + fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); + return false; + } + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); + if (result != MA_SUCCESS) { + fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); + return false; + } + decoder.initialized = true; + } #else - if ((result = ma_decoder_init_memory(fname.c_str(), fname.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); - + if (!decoder.initialized) { + fprintf(stderr, "error: failed to read audio data from (%s)\n", fname.c_str()); return false; } #endif @@ -128,8 +143,6 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: } } - ma_decoder_uninit(&decoder); - return true; } From afa2ea544fb4b0448916b4a31ecd33c8685bd482 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 19 May 2026 08:58:43 +0200 Subject: [PATCH 083/289] whisper : set bench data for each iteration (#3812) * whisper : set bench data for each iteration This commit updates whisper_bench_ggml_mul_mat_str to intialize the tensors data for each iteration. The motivation for this is that is currently possible for a previous run's results, F32 values, to leak into the next run. When it is time for the F16 iteration then F32 results can cause NaN values to appear in the tensor values causing the F16 iteration to fail. Refs:https://github.com/ggml-org/whisper.cpp/actions/runs/25901678402/job/76152894644?pr=3735 * ci : set GGML_NATIVE=OFF if x86_64 This commit sets GGML_NATIVE=OFF for x86_64 architectures. The motivation for this is to try to get CI to pass and the theory is that the libggml-cpu.so library in the ccache might have been built by a runner that supports a different instruction set. When another runner that does not support that instruction set tries to use it, it will fail with a segmentation fault. I'm not sure about this yet but going to try this out and if it does not work I'll ssh into the runner to debug further. --- ci/run.sh | 4 ++++ src/whisper.cpp | 12 +++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index cbe28442e16..b03fdf1c6b1 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -50,6 +50,10 @@ fi CMAKE_EXTRA="-DWHISPER_FATAL_WARNINGS=ON" +if [[ "$(uname -m)" == "x86_64" ]]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF" +fi + if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" fi diff --git a/src/whisper.cpp b/src/whisper.cpp index 210ca597fb4..0fe29a4541e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -8258,9 +8258,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) std::vector buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead()); - // put a bunch of random data in the buffer - for (size_t i = 0; i < buf.size(); i++) buf[i] = i; - for (int j = 0; j < (int) sizes.size(); j++) { int n_q4_0 = 0; int n_q4_1 = 0; @@ -8304,6 +8301,15 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + // set tensor data after allocation so previous iteration results don't corrupt it. + { + uint8_t * a_data = (uint8_t *) a->data; + for (size_t ii = 0; ii < ggml_nbytes(a); ii++) a_data[ii] = ii & 0x3F; + + uint8_t * b_data = (uint8_t *) b->data; + for (size_t ii = 0; ii < ggml_nbytes(b); ii++) b_data[ii] = ii & 0x3F; + } + struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); struct ggml_cgraph * gf = ggml_new_graph(ctx0); From 8443cf05e3fa8ce1b32348e1bcbcf8fc31f7f3ae Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 21 May 2026 10:59:58 +0200 Subject: [PATCH 084/289] ci : use github ubuntu-22.04-arm runner instead of qemu (#3815) * ci : use github ubuntu-22.04-arm runner instead of qemu This commit updates the ubuntu-22-gcc-arm64 job to use a arm github runner instead of QEMU. The motivation for this is that we get intermittent failure specifically related to QEMU. For example: ```console Segmentation fault (core dumped) qemu: uncaught target signal 11 (Segmentation fault) - core dumped Segmentation fault (core dumped) dpkg: error processing package libc-bin (--configure): installed libc-bin package post-installation script subprocess returned error exit status 139 Processing triggers for ca-certificates (20240203~22.04.1) ... Updating certificates in /etc/ssl/certs... 0 added, 0 removed; done. Running hooks in /etc/ca-certificates/update.d... done. Errors were encountered while processing: libc-bin E: Sub-process /usr/bin/dpkg returned an error code (1) ``` This is an attempt to try to avoid QEMU and hence avoid this issue. * ci : remove QEMU where possible --- .github/workflows/build.yml | 122 +++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index be3f78a3f5b..7ace04e1207 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -150,34 +150,21 @@ jobs: ubuntu-22-arm64: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/arm64] + runs-on: ubuntu-22.04-arm steps: - name: Clone uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} + - name: Install dependencies run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sudo apt-get update + sudo apt-get install -y build-essential libsdl2-dev cmake git - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - cmake --build build --config Release -j $(nproc)' + - name: Build + run: | + cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a + cmake --build build --config Release -j $(nproc) ubuntu-22-arm-v7: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || @@ -305,36 +292,34 @@ jobs: ubuntu-22-gcc-arm64: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 + runs-on: ubuntu-22.04-arm strategy: fail-fast: false matrix: build: [Debug, Release] - arch: [linux/arm64] steps: - name: Clone uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake libsdl2-dev git - - name: Build ${{ matrix.arch }} + - name: Configure CMake run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + cmake . \ + -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - make - ctest -L gh --output-on-failure' + - name: Build and Test + run: | + make + ctest -L gh --output-on-failure ubuntu-22-gcc-arm-v7: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || @@ -382,7 +367,7 @@ jobs: #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] # TODO: arm/v7 disabled due to clang bug # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 - arch: [linux/amd64, linux/arm64, linux/ppc64le] + arch: [linux/amd64, linux/ppc64le] steps: - name: Clone @@ -407,6 +392,36 @@ jobs: make ctest -L gh --output-on-failure' + ubuntu-22-clang-arm64: + if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || + github.event.inputs.run_type == 'full-ci' }} + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y clang build-essential cmake libsdl2-dev git + + - name: Build and Test + run: | + cmake . -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + make + ctest -L gh --output-on-failure + ubuntu-22-gcc-sanitized: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || github.event.inputs.run_type == 'full-ci' }} @@ -416,32 +431,23 @@ jobs: fail-fast: false matrix: sanitizer: [ADDRESS, THREAD, UNDEFINED] - arch: [linux/amd64] steps: - name: Clone uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} + - name: Install dependencies run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sudo apt-get update + sudo apt-get install -y build-essential cmake git - apt update - apt install -y build-essential cmake git - cmake . -DCMAKE_BUILD_TYPE=Debug \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF - make - ctest -L gh --output-on-failure' + - name: Build and Test + run: | + cmake . -DCMAKE_BUILD_TYPE=Debug \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + ctest -L gh --output-on-failure ubuntu-22-cmake-sycl: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || From 0ccd896f5b882628e1c077f9769735ef4ce52860 Mon Sep 17 00:00:00 2001 From: Pascal Date: Fri, 22 May 2026 08:27:35 +0200 Subject: [PATCH 085/289] common : fix server /inference fails to decode in-memory audio (regression) (#3818) * common: add memory buffer overload of read_audio_data whisper-server /inference without --convert passed the uploaded file bytes to read_audio_data as a filename, so ma_decoder_init_file tried to open a path starting with "RIFF" and failed. every request returned HTTP 400 "Invalid request" on builds without WHISPER_FFMPEG, which is the default. factor the PCM extraction into a shared helper and add an overload that decodes straight from a memory buffer via ma_decoder_init_memory, which the function already used for the stdin path. server now calls it with the upload content. the filename overload behavior is unchanged. --- examples/common-whisper.cpp | 79 ++++++++++++++++++++++--------------- examples/common-whisper.h | 8 ++++ examples/server/server.cpp | 3 +- 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index 977527a0ca5..d29166b50d8 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -39,6 +39,42 @@ extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); #endif +// extract f32 PCM frames from an initialized decoder, downmix to mono and keep the stereo split +static bool read_audio_from_decoder(ma_decoder & decoder, std::vector & pcmf32, std::vector> & pcmf32s, bool stereo) { + ma_result result; + ma_uint64 frame_count; + ma_uint64 frames_read; + + if ((result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count)) != MA_SUCCESS) { + fprintf(stderr, "error: failed to retrieve the length of the audio data (%s)\n", ma_result_description(result)); + return false; + } + + pcmf32.resize(stereo ? frame_count*2 : frame_count); + + if ((result = ma_decoder_read_pcm_frames(&decoder, pcmf32.data(), frame_count, &frames_read)) != MA_SUCCESS) { + fprintf(stderr, "error: failed to read the frames of the audio data (%s)\n", ma_result_description(result)); + return false; + } + + if (stereo) { + std::vector stereo_data = pcmf32; + pcmf32.resize(frame_count); + for (uint64_t i = 0; i < frame_count; i++) { + pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]); + } + pcmf32s.resize(2); + pcmf32s[0].resize(frame_count); + pcmf32s[1].resize(frame_count); + for (uint64_t i = 0; i < frame_count; i++) { + pcmf32s[0][i] = stereo_data[2*i]; + pcmf32s[1][i] = stereo_data[2*i + 1]; + } + } + + return true; +} + bool read_audio_data(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { std::vector audio_data; // used for pipe input from stdin or ffmpeg decoding output @@ -109,41 +145,22 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: #endif } - ma_uint64 frame_count; - ma_uint64 frames_read; - - if ((result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to retrieve the length of the audio data (%s)\n", ma_result_description(result)); - - return false; - } - - pcmf32.resize(stereo ? frame_count*2 : frame_count); - - if ((result = ma_decoder_read_pcm_frames(&decoder, pcmf32.data(), frame_count, &frames_read)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read the frames of the audio data (%s)\n", ma_result_description(result)); - - return false; - } - - if (stereo) { - std::vector stereo_data = pcmf32; - pcmf32.resize(frame_count); + return read_audio_from_decoder(decoder.decoder, pcmf32, pcmf32s, stereo); +} - for (uint64_t i = 0; i < frame_count; i++) { - pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]); - } +// decode audio bytes already held in memory +bool read_audio_data(const char * buffer, size_t buffer_size, std::vector & pcmf32, std::vector> & pcmf32s, bool stereo) { + ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); + ma_decoder decoder; - pcmf32s.resize(2); - pcmf32s[0].resize(frame_count); - pcmf32s[1].resize(frame_count); - for (uint64_t i = 0; i < frame_count; i++) { - pcmf32s[0][i] = stereo_data[2*i]; - pcmf32s[1][i] = stereo_data[2*i + 1]; - } + if (ma_decoder_init_memory(buffer, buffer_size, &decoder_config, &decoder) != MA_SUCCESS) { + fprintf(stderr, "error: failed to decode audio data from memory buffer\n"); + return false; } - return true; + bool ok = read_audio_from_decoder(decoder, pcmf32, pcmf32s, stereo); + ma_decoder_uninit(&decoder); + return ok; } // 500 -> 00:05.000 diff --git a/examples/common-whisper.h b/examples/common-whisper.h index 4134362150a..8714c381046 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -14,6 +14,14 @@ bool read_audio_data( std::vector> & pcmf32s, bool stereo); +// decode audio bytes already held in memory (uploaded file, network buffer) +bool read_audio_data( + const char * buffer, + size_t buffer_size, + std::vector & pcmf32, + std::vector> & pcmf32s, + bool stereo); + // convert timestamp to string, 6000 -> 01:00.000 std::string to_timestamp(int64_t t, bool comma = false); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 590378b725f..aae74c3d840 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -868,8 +868,7 @@ int main(int argc, char ** argv) { // remove temp file std::remove(temp_filename.c_str()); } else { - if (!::read_audio_data(audio_file.content, pcmf32, pcmf32s, params.diarize)) - { + if (!::read_audio_data(audio_file.content.data(), audio_file.content.size(), pcmf32, pcmf32s, params.diarize)) { fprintf(stderr, "error: failed to read audio data\n"); const std::string error_resp = "{\"error\":\"failed to read audio data\"}"; res.status = 400; From b3877e10c0a8435c53209b2ebfbaa8359f7baaae Mon Sep 17 00:00:00 2001 From: OrbisAI Security Date: Mon, 25 May 2026 11:49:23 +0530 Subject: [PATCH 086/289] fix: in bindings/ruby/test/jfk_reader/jfk_reader in jfk_reader.c (#3756) * fix: V-002 security vulnerability Automated security fix generated by Orbis Security AI * fix(ruby): use Ruby allocator macros in jfk_reader and fix memory leak - Replace calloc/free with ALLOC_N/xfree to match Ruby binding conventions (ALLOC_N handles overflow checking and raises NoMemoryError on failure) - Free temporary samples buffer after conversion loop (was leaked) - Add NULL check for fopen return value with rb_raise - Add comment clarifying n_samples is a compile-time constant Co-Authored-By: Claude Opus 4.6 * fix(ruby): return false instead of rb_raise in memory_view callback rb_memory_view_get_func_t callbacks should communicate errors via return value (false), not exceptions. rb_memory_view_get has no exception-handling wrapper around get_func calls. Co-Authored-By: Claude Opus 4.6 * replacing ALLOC_N with rb_protect as ALLOC_N raises Ruby exceptions --------- Co-authored-by: Claude Opus 4.6 --- bindings/ruby/test/jfk_reader/jfk_reader.c | 57 +++++++++++++++++++--- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/bindings/ruby/test/jfk_reader/jfk_reader.c b/bindings/ruby/test/jfk_reader/jfk_reader.c index 6657176e767..62207aaa411 100644 --- a/bindings/ruby/test/jfk_reader/jfk_reader.c +++ b/bindings/ruby/test/jfk_reader/jfk_reader.c @@ -2,6 +2,24 @@ #include #include +typedef struct { + VALUE audio_path; + int n_samples; + const char *audio_path_str; + float *data; + short *samples; +} jfk_alloc_args; + +static VALUE +jfk_reader_alloc_resources(VALUE arg) +{ + jfk_alloc_args *a = (jfk_alloc_args *)arg; + a->audio_path_str = StringValueCStr(a->audio_path); + a->data = ALLOC_N(float, a->n_samples); + a->samples = ALLOC_N(short, a->n_samples); + return Qnil; +} + static VALUE jfk_reader_initialize(VALUE self, VALUE audio_path) { @@ -13,21 +31,42 @@ static bool jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags) { VALUE audio_path = rb_iv_get(obj, "audio_path"); - const char *audio_path_str = StringValueCStr(audio_path); + // n_samples is a fixed constant (not derived from user input). const int n_samples = 176000; - float *data = (float *)malloc(n_samples * sizeof(float)); - short *samples = (short *)malloc(n_samples * sizeof(short)); - FILE *file = fopen(audio_path_str, "rb"); + + jfk_alloc_args args = { + .audio_path = audio_path, + .n_samples = n_samples, + .audio_path_str = NULL, + .data = NULL, + .samples = NULL, + }; + + int state; + rb_protect(jfk_reader_alloc_resources, (VALUE)&args, &state); + if (state) { + if (args.samples) xfree(args.samples); + if (args.data) xfree(args.data); + return false; + } + + FILE *file = fopen(args.audio_path_str, "rb"); + if (file == NULL) { + xfree(args.samples); + xfree(args.data); + return false; + } fseek(file, 78, SEEK_SET); - fread(samples, sizeof(short), n_samples, file); + fread(args.samples, sizeof(short), n_samples, file); fclose(file); for (int i = 0; i < n_samples; i++) { - data[i] = samples[i]/32768.0; + args.data[i] = args.samples[i] / 32768.0; } + xfree(args.samples); view->obj = obj; - view->data = (void *)data; + view->data = (void *)args.data; view->byte_size = sizeof(float) * n_samples; view->readonly = true; view->format = "f"; @@ -45,6 +84,10 @@ jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags) static bool jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view) { + if (view->data) { + xfree(view->data); + view->data = NULL; + } return true; } From e414ecf67424f0cd69a3520f99439122ce9aaa1f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 25 May 2026 11:25:15 +0200 Subject: [PATCH 087/289] cmake : add CMakePresets.json [no ci] (#3808) This commit adds a CMakePresets.json file similar to the one in llama.cpp. The motivation for this is that this provides sharable named configuration which can be used with cmake --preset . It also allows for extendins these preset with a CMakeUserPresets.json for specific hardware (like CPUs), architectures, and toolchains etc. --- .gitignore | 1 + CMakePresets.json | 95 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 CMakePresets.json diff --git a/.gitignore b/.gitignore index 957eeb75456..6eb8ff45915 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ .DS_Store .vimspector.json /CMakeSettings.json +/CMakeUserPresets.json /talk-llama.dSYM/ build/ diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 00000000000..b5afeb3c0f2 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,95 @@ +{ + "version": 4, + "configurePresets": [ + { + "name": "base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { + "name": "sycl-base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_CXX_COMPILER": "icx", + "CMAKE_C_COMPILER": "cl", + "GGML_SYCL": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, + { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, + { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } }, + + { + "name": "x64-windows-llvm", "hidden": true, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/x64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-windows-llvm", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-apple-clang", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake" + } + }, + { + "name": "x64-linux-gcc", "hidden": true, + "cacheVariables": { + "CMAKE_C_COMPILER": "gcc", + "CMAKE_CXX_COMPILER": "g++" + } + }, + { "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] }, + { "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] }, + { "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] }, + { "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] }, + + { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, + { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, + { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, + + { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] }, + { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, + { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, + + { "name": "x64-windows-llvm-debug", "inherits": [ "base", "x64-windows-llvm", "debug" ] }, + { "name": "x64-windows-llvm-release", "inherits": [ "base", "x64-windows-llvm", "release" ] }, + { "name": "x64-windows-llvm-reldbg", "inherits": [ "base", "x64-windows-llvm", "reldbg" ] }, + { "name": "x64-windows-llvm+static-release", "inherits": [ "base", "x64-windows-llvm", "reldbg", "static" ] }, + + { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] }, + { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, + { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, + + { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] }, + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, + + { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] }, + { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] } + ] +} From 06cfc3653b256e4d6e02553e70ddce8bbc625ede Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Thu, 14 May 2026 01:39:14 -0400 Subject: [PATCH 088/289] SYCL: fix multi-GPU system RAM exhaustion by using Level Zero allocations (llama/21597) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL: fix multi-GPU system RAM exhaustion by using Level Zero allocations Replace sycl::malloc_device with zeMemAllocDevice for GPU memory allocation in the SYCL backend. sycl::malloc_device triggers the xe kernel driver's DMA-buf/TTM path which mirrors every VRAM allocation 1:1 in system RAM. zeMemAllocDevice uses the SVM/P2P path with no host staging. On a dual Intel Arc Pro B70 system (64GB VRAM, 64GB RAM), a 15.6 GiB model consumed 60 GiB of system RAM via sycl::malloc_device, causing OOM crashes. With zeMemAllocDevice, the same workload uses ~6.7 GiB of system RAM with no performance regression. All Level Zero calls include automatic fallback to the original SYCL allocation path if Level Zero interop is unavailable. * SYCL: address review feedback - remove try/catch, check device types, deduplicate - Remove try/catch from malloc/free/memcpy helpers, check backend and device type upfront instead (ggml_sycl_is_level_zero, ggml_sycl_is_dgpu) - Move shared helpers (is_level_zero, is_dgpu, free_device) to common.cpp and declare in common.hpp to eliminate code duplication - Use SYCL_CHECK(CHECK_TRY_ERROR()) for fallback sycl::free calls - Guard dev2dev_memcpy L0 path to dGPU-to-dGPU only, preserving the host-staged path for iGPU-to-dGPU transfers - Add Windows Level Zero SDK path detection (LEVEL_ZERO_V1_SDK_PATH) in CMakeLists.txt (co-authored with @arthw) * SYCL: add build/runtime flags for Level Zero, address review feedback Implements the architecture suggested by @arthw: compile-time and runtime flags to cleanly separate Level Zero and SYCL memory API paths. - Add GGML_SYCL_SUPPORT_LEVEL_ZERO cmake option (default ON). All Level Zero code is wrapped in #ifdef so the build works on systems without the Level Zero SDK installed (e.g. CPU-only CI servers). Both the loader library and headers are checked before enabling. - Add GGML_SYCL_ENABLE_LEVEL_ZERO runtime env var (default 1). Controls whether Level Zero or SYCL memory APIs are used. Only one API style is used per session, no mixing. If Level Zero is enabled but the devices don't support the Level Zero backend, it auto-disables with a warning. - Remove Level Zero code from dpct_malloc. It was unused (dpct::device_memory is not called anywhere in the backend) and used try/catch for flow control. - Update SYCL.md with documentation for both new parameters. Tested on Intel Arc Pro B70 (32GB), single-GPU and dual-GPU, with both GGML_SYCL_SUPPORT_LEVEL_ZERO=ON and OFF builds. AI-assisted development (Claude). Code reviewed and tested on my hardware. * SYCL: unify Level Zero malloc/free call sites, address review feedback Move ggml_sycl_malloc_device to common.cpp alongside ggml_sycl_free_device. Both functions are now unconditionally available — Level Zero code is #ifdef'd inside the functions, not at call sites. All call sites use uniform SYCL_CHECK(CHECK_TRY_ERROR()) wrapping with no #ifdef blocks. Addresses arthw's review: wrap all malloc/free in SYCL_CHECK for stack traces on failure, eliminate duplicated #ifdef/else patterns at 6 call sites (-29 lines net). Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: add Level Zero SDK to CI, fix device check and missed alloc paths Add Level Zero SDK installation to Ubuntu and Windows SYCL CI jobs so the Level Zero code path is compiled and tested in CI. Fix two bugs found during extended dual-GPU testing (no ONEAPI_DEVICE_SELECTOR set): - The Level Zero backend check was iterating all SYCL devices including CPU. The OpenCL CPU device caused Level Zero to be disabled for the GPUs, defeating the fix on multi-GPU systems. Added is_gpu() filter so only GPU devices are checked. - sycl_ext_malloc_device/sycl_ext_free (tensor reorder temp buffers) were still calling sycl::malloc/sycl::free directly, bypassing the Level Zero path. Routed through ggml_sycl_malloc_device/free_device for consistency with the other device memory call sites. Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: address arthw review feedback on Level Zero memory API structure - Move ggml_sycl_malloc_device to static function in ggml-sycl.cpp; only ggml_sycl_free_device (used by common.cpp) stays in common.cpp - Switch both helpers to use g_ggml_sycl_enable_level_zero global instead of per-call queue backend checks - Remove #ifdef wrapper from global definition; always declare at 0, add #else branch in init block so it stays 0 when L0 not compiled in - Update init loop comment to explain GPU-only device check - CMakeLists: message(STATUS) before the if block; align option wording AI-assisted implementation. Reviewed and tested on dual Intel Arc Pro B70 (32 GB each): test-backend-ops OK on both GPUs, single/dual-GPU Q4_K_M and Q8_0 bench correct, zeMemAllocDevice GTT delta confirmed <5 MiB per 4 GiB allocation (vs ~4 GiB shadow with sycl::malloc_device). Co-Authored-By: Claude Sonnet 4.6 * SYCL: remove unused cstdio/cstdlib includes from common.cpp Leftover from the deleted ggml_sycl_queue_supports_level_zero helper. Co-authored-by: Claude Sonnet 4.6 * Apply suggestions from code review Co-authored-by: Neo Zhang * SYCL: preserve Level Zero allocation path during early malloc * ci: fix Level Zero package conflict in Intel Docker build * ci: find Level Zero loader in oneAPI package step * ci: allow Windows SYCL package without Level Zero DLL --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Neo Zhang --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-sycl/CMakeLists.txt | 29 +++++++++ ggml/src/ggml-sycl/common.cpp | 76 +++++++++++++++++++++++- ggml/src/ggml-sycl/common.hpp | 4 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 98 +++++++++++++++++++++++++------ 5 files changed, 187 insertions(+), 21 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4e65cd68b4e..bdeca34bf9f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -249,6 +249,7 @@ option(GGML_SYCL "ggml: use SYCL" option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) +option(GGML_SYCL_SUPPORT_LEVEL_ZERO "ggml: use Level Zero API in SYCL backend" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 8f44c6ed080..180de92202d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -39,6 +39,18 @@ if (WIN32) set(CMAKE_CXX_COMPILER "icx") set(CMAKE_CXX_COMPILER_ID "IntelLLVM") endif() + # Level Zero SDK path for Windows (only when GGML_SYCL_SUPPORT_LEVEL_ZERO is enabled) + if(GGML_SYCL_SUPPORT_LEVEL_ZERO) + if(DEFINED ENV{LEVEL_ZERO_V1_SDK_PATH}) + set(LEVEL_ZERO_V1_SDK_PATH $ENV{LEVEL_ZERO_V1_SDK_PATH}) + if(EXISTS "${LEVEL_ZERO_V1_SDK_PATH}") + target_include_directories(ggml-sycl PRIVATE "${LEVEL_ZERO_V1_SDK_PATH}/include") + set(LEVEL_ZERO_V1_SDK_LIB_PATH "${LEVEL_ZERO_V1_SDK_PATH}/lib") + else() + message(WARNING "LEVEL_ZERO_V1_SDK_PATH set but folder not found: ${LEVEL_ZERO_V1_SDK_PATH}") + endif() + endif() + endif() endif() macro(detect_and_find_package package_name) @@ -93,6 +105,23 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") +message(STATUS "GGML_SYCL_SUPPORT_LEVEL_ZERO ${GGML_SYCL_SUPPORT_LEVEL_ZERO}") +if (GGML_SYCL_SUPPORT_LEVEL_ZERO) + # Link against Level Zero loader for direct device memory allocation. + # Avoids sycl::malloc_device triggering DMA-buf/TTM system RAM staging + # in the xe kernel driver during multi-GPU inference. + find_path(LEVEL_ZERO_INCLUDE_DIR level_zero/ze_api.h HINTS ${ONEAPI_ROOT}/include ${LEVEL_ZERO_V1_SDK_PATH}/include) + find_library(ZE_LOADER_LIB ze_loader HINTS ${ONEAPI_ROOT}/lib ${LEVEL_ZERO_V1_SDK_LIB_PATH} ENV LD_LIBRARY_PATH) + if(ZE_LOADER_LIB AND LEVEL_ZERO_INCLUDE_DIR) + target_link_libraries(ggml-sycl PRIVATE ${ZE_LOADER_LIB}) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_SUPPORT_LEVEL_ZERO) + message(STATUS "Level Zero loader found: ${ZE_LOADER_LIB}") + message(STATUS "Level Zero headers found: ${LEVEL_ZERO_INCLUDE_DIR}") + else() + message(WARNING "Level Zero loader or headers not found, Level Zero support disabled") + endif() +endif() + # Link against oneDNN set(GGML_SYCL_DNNL 0) if(GGML_SYCL_DNN) diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 05fd5ef46c7..ae08abad81b 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -11,6 +11,10 @@ // #include "common.hpp" +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #include "ggml-backend-impl.h" #include "ggml-impl.h" @@ -55,6 +59,20 @@ bool gpu_has_xmx(sycl::device &dev) { return dev.has(sycl::aspect::ext_intel_matrix); } +static int ggml_sycl_get_env(const char *env_name, int default_val) { + char *user_device_string = getenv(env_name); + int user_number = default_val; + + unsigned n; + if (user_device_string != NULL && + sscanf(user_device_string, " %u", &n) == 1) { + user_number = (int)n; + } else { + user_number = default_val; + } + return user_number; +} + int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { const int64_t max_range = std::numeric_limits::max(); int64_t sycl_down_blk_size = block_size; @@ -66,6 +84,61 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block return sycl_down_blk_size; } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_use_level_zero_device_alloc(sycl::queue &q) { + return ggml_sycl_get_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1) && + q.get_device().is_gpu() && + q.get_backend() == sycl::backend::ext_oneapi_level_zero; +} +#endif + +// Use Level Zero zeMemAllocDevice to avoid sycl::malloc_device triggering +// DMA-buf/TTM system RAM staging in the xe kernel driver during multi-GPU inference. +// The decision is made from the queue and runtime env because large buffers can be +// allocated before ggml_check_sycl() initializes g_ggml_sycl_enable_level_zero. +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + void *ptr = nullptr; + auto ze_ctx = sycl::get_native(q.get_context()); + auto ze_dev = sycl::get_native(q.get_device()); +#ifdef ZE_RELAXED_ALLOCATION_LIMITS_EXP_NAME + ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = { + ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC, + nullptr, + ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE, + }; + ze_device_mem_alloc_desc_t alloc_desc = { + ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, + &relaxed_desc, + 0, + 0, + }; +#else + ze_device_mem_alloc_desc_t alloc_desc = {ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, nullptr, 0, 0}; +#endif + ze_result_t r = zeMemAllocDevice(ze_ctx, &alloc_desc, size, 64, ze_dev, &ptr); + if (r == ZE_RESULT_SUCCESS && ptr) { + return ptr; + } + return nullptr; + } +#endif + return sycl::malloc_device(size, q); +} + +void ggml_sycl_free_device(void *ptr, sycl::queue &q) { + if (!ptr) return; +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + auto ze_ctx = sycl::get_native(q.get_context()); + zeMemFree(ze_ctx, ptr); + return; + } +#endif + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, q))); +} + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { for (int i = 0; i < ggml_sycl_info().device_count; ++i) { for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { @@ -75,8 +148,7 @@ void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector str } if (extra->data_device[i] != nullptr && streams.size()>0) { ggml_sycl_set_device(i); - SYCL_CHECK( - CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(extra->data_device[i], *(streams[i])))); } } delete extra; diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index eec36e8db9a..96bc1c98bd9 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -310,6 +310,10 @@ struct ggml_tensor_extra_gpu { optimize_feature optimized_feature; }; +extern int g_ggml_sycl_enable_level_zero; +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q); +void ggml_sycl_free_device(void *ptr, sycl::queue &q); + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); namespace sycl_ex = sycl::ext::oneapi::experimental; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 57cc4ffb6f7..f5d10b56de0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -30,6 +30,10 @@ #include #include +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include #endif @@ -68,6 +72,7 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_enable_level_zero = 0; int g_ggml_sycl_enable_flash_attention = 1; @@ -223,6 +228,27 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1); +#else + g_ggml_sycl_enable_level_zero = 0; +#endif + if (g_ggml_sycl_enable_level_zero) { + // Verify all GPU devices use the Level Zero backend before enabling L0 APIs. + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); i++) { + auto & q = dpct::dev_mgr::instance().get_device(i).default_queue(); + if (!q.get_device().is_gpu()) { + continue; + } + if (q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + g_ggml_sycl_enable_level_zero = 0; + break; + } + } + } #ifdef SYCL_FLASH_ATTN g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); @@ -253,6 +279,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n"); #endif +#if defined(GGML_SYCL_SUPPORT_LEVEL_ZERO) + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); +#endif GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); @@ -262,6 +293,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); #endif +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: %d\n", g_ggml_sycl_enable_level_zero); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: Level Zero disabled by compile flag\n"); +#endif #if GGML_SYCL_DNNL GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else @@ -371,7 +407,7 @@ struct ggml_backend_sycl_buffer_context { ~ggml_backend_sycl_buffer_context() { if (dev_ptr != nullptr) { ggml_sycl_set_device(device); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(dev_ptr, *stream))); } //release extra used by tensors @@ -504,8 +540,43 @@ catch (sycl::exception const &exc) { std::exit(1); } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_is_l0_discrete_gpu(sycl::queue &q) { + if (!q.get_device().is_gpu() || q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + return false; + } + + ze_device_handle_t ze_dev = sycl::get_native(q.get_device()); + ze_device_properties_t props = {}; + props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + ze_result_t r = zeDeviceGetProperties(ze_dev, &props); + return r == ZE_RESULT_SUCCESS && !(props.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED); +} +#endif + static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + // Use Level Zero direct copy for dGPU-to-dGPU transfers. + const bool l0_copy_supported = + ggml_sycl_is_l0_discrete_gpu(q_dst) && ggml_sycl_is_l0_discrete_gpu(q_src); + if (g_ggml_sycl_enable_level_zero && l0_copy_supported) { + auto ze_ctx = sycl::get_native(q_dst.get_context()); + auto ze_dev = sycl::get_native(q_dst.get_device()); + ze_command_queue_desc_t cq_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, nullptr, 0, 0, + 0, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL}; + ze_command_list_handle_t cl; + ze_result_t r = zeCommandListCreateImmediate(ze_ctx, ze_dev, &cq_desc, &cl); + if (r == ZE_RESULT_SUCCESS) { + r = zeCommandListAppendMemoryCopy(cl, ptr_dst, ptr_src, size, nullptr, 0, nullptr); + zeCommandListDestroy(cl); + if (r == ZE_RESULT_SUCCESS) { + return; + } + } + } +#endif + // Host-staged copy char *host_buf = (char *)malloc(size); q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); q_dst.memcpy((char *)ptr_dst, host_buf, size).wait(); @@ -675,8 +746,7 @@ ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 void * dev_ptr; - SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)ggml_sycl_malloc_device(size, *stream))); if (!dev_ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); return nullptr; @@ -917,18 +987,10 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - // FIXME: do not crash if SYCL Buffer alloc fails - // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_sycl_set_device(i); const queue_ptr stream = ctx->streams[i]; char * buf; - /* - DPCT1009:208: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)ggml_sycl_malloc_device(size, *stream))); if (!buf) { char err_buf[1024]; snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); @@ -1306,7 +1368,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(b.ptr, *qptr))); pool_size -= b.size; } } @@ -1374,9 +1436,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( - look_ahead_size, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *)ggml_sycl_malloc_device(look_ahead_size, *qptr))); if (!ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); return nullptr; @@ -1404,7 +1464,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } } GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(ptr, *qptr))); pool_size -= size; } }; @@ -3405,7 +3465,7 @@ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - return sycl::malloc(size, *stream, sycl::usm::alloc::device); + return ggml_sycl_malloc_device(size, *stream); } static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { @@ -3419,7 +3479,7 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - sycl::free(ptr, *stream); + ggml_sycl_free_device(ptr, *stream); } // RAII wrapper for temporary reorder buffers with optional host memory fallback. From 97ba44338fcda566e10644057135f08ea820ff60 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 14 May 2026 10:36:54 +0200 Subject: [PATCH 089/289] vulkan: fix matmul integer pipeline selection (llama/23005) * vulkan: fix matmul integer pipeline selection * gate pipeline creation with the right bools --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a0a556206d5..8c4cf9ef1db 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3954,13 +3954,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ - if (device->mul_mat ## ID ## _l[TYPE]) { \ + if (device->mul_mat ## ID ## _l_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _m[TYPE]) { \ + if (device->mul_mat ## ID ## _m_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _s[TYPE]) { \ + if (device->mul_mat ## ID ## _s_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ @@ -4131,11 +4131,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l[TYPE]) \ + if (device->mul_mat ## ID ## _l_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m[TYPE]) \ + if (device->mul_mat ## ID ## _m_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s[TYPE]) \ + if (device->mul_mat ## ID ## _s_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); @@ -5716,12 +5716,12 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } - device->mul_mat_l_int[i] = true; - device->mul_mat_m_int[i] = true; - device->mul_mat_s_int[i] = true; - device->mul_mat_id_l_int[i] = true; - device->mul_mat_id_m_int[i] = true; - device->mul_mat_id_s_int[i] = true; + device->mul_mat_l_int[i] = device->mul_mat_l[i]; + device->mul_mat_m_int[i] = device->mul_mat_m[i]; + device->mul_mat_s_int[i] = device->mul_mat_s[i]; + device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i]; + device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i]; + device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i]; } From f0223903aa33894bcc6d7f23b7b7466dc95c33dd Mon Sep 17 00:00:00 2001 From: alex-spacemit Date: Thu, 14 May 2026 17:39:30 +0800 Subject: [PATCH 090/289] ggml-cpu: Add IME2 Instruction Support for the SpacemiT Backend (llama/22863) --- ggml/src/ggml-cpu/CMakeLists.txt | 13 + ggml/src/ggml-cpu/cmake/FindSMTIME.cmake | 32 + ggml/src/ggml-cpu/ggml-cpu.c | 12 + ggml/src/ggml-cpu/spacemit/ime.cpp | 2089 ++++-- ggml/src/ggml-cpu/spacemit/ime.h | 8 + ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp | 3363 ++-------- ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp | 5768 +++++++++++++++++ ggml/src/ggml-cpu/spacemit/ime_env.cpp | 320 + ggml/src/ggml-cpu/spacemit/ime_env.h | 55 + ggml/src/ggml-cpu/spacemit/ime_kernels.h | 201 +- ggml/src/ggml-cpu/spacemit/repack.cpp | 1795 +++++ ggml/src/ggml-cpu/spacemit/repack.h | 14 + ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp | 3178 +++++++++ ggml/src/ggml-cpu/spacemit/rvv_kernels.h | 95 + ggml/src/ggml-cpu/spacemit/spine_barrier.h | 34 + ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp | 760 +++ ggml/src/ggml-cpu/spacemit/spine_mem_pool.h | 32 + ggml/src/ggml-cpu/spacemit/spine_tcm.h | 409 ++ 18 files changed, 14706 insertions(+), 3472 deletions(-) create mode 100644 ggml/src/ggml-cpu/cmake/FindSMTIME.cmake create mode 100644 ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ime_env.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ime_env.h create mode 100644 ggml/src/ggml-cpu/spacemit/repack.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/repack.h create mode 100644 ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/rvv_kernels.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_barrier.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/spine_mem_pool.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_tcm.h diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 869c7b238bf..f3eccff7d72 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -450,12 +450,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/repack.cpp ) if (GGML_CPU_RISCV64_SPACEMIT) + include(ggml-cpu/cmake/FindSMTIME.cmake) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) list(APPEND GGML_CPU_SOURCES ggml-cpu/spacemit/ime.cpp ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/spine_mem_pool.cpp + ggml-cpu/spacemit/spine_mem_pool.h + ggml-cpu/spacemit/repack.cpp + ggml-cpu/spacemit/repack.h + ggml-cpu/spacemit/ime_env.cpp + ggml-cpu/spacemit/ime_env.h ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime2_kernels.cpp ggml-cpu/spacemit/ime_kernels.h + ggml-cpu/spacemit/rvv_kernels.cpp + ggml-cpu/spacemit/rvv_kernels.h ) endif() if(NOT GGML_CPU_ALL_VARIANTS) @@ -485,6 +495,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_RV_ZBA) + string(APPEND MARCH_STR "_zba") + endif() if (GGML_CPU_RISCV64_SPACEMIT) # `xsmtvdotii' is only required for GCC >= 15. if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND diff --git a/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake new file mode 100644 index 00000000000..c8a4d4b4ec9 --- /dev/null +++ b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake @@ -0,0 +1,32 @@ +include(CheckCSourceRuns) + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)" AND GGML_CPU_RISCV64_SPACEMIT) + set(SMT_MARCH_STR "-march=rv64gcv_zfh_zvfh_zba_zicbop") + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND SMT_MARCH_STR "_xsmtvdotii") + endif() + set(CMAKE_REQUIRED_FLAGS "${SMT_MARCH_STR}") + + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vfwmadot v2, v0, v1, fp16\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFWMADOT_FP16) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot1 v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOTN) + check_c_source_compiles("int main() {__asm__ volatile(\"vpack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK) + check_c_source_compiles("int main() {__asm__ volatile(\"vnspack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + unset(CMAKE_REQUIRED_FLAGS) + + list(APPEND RISCV64_SPACEMIT_IME_SPEC "") + if (SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + set(RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME1") + endif() + + if (SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4 AND SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK AND SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + list(APPEND RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME2") + endif() + + message("RISCV64_SPACEMIT_IME_SPEC: ${RISCV64_SPACEMIT_IME_SPEC}") +endif() diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b7acafdaa8..7b05edf6b75 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -50,6 +50,10 @@ #include "llamafile/sgemm.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + // Note: once we move threading into a separate C++ file // will use std::hardware_destructive_interference_size instead of hardcoding it here // and we'll use C++ attribute syntax. @@ -3011,7 +3015,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const struct ggml_cgraph * cgraph = tp->cgraph; const struct ggml_cplan * cplan = tp->cplan; +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(state->ith); +#else set_numa_thread_affinity(state->ith); +#endif struct ggml_compute_params params = { /*.ith =*/ state->ith, @@ -3068,6 +3076,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_barrier(state->threadpool); +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(state->ith); +#endif + return 0; } diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 91fe1925eaa..9563ea3e4bd 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -3,19 +3,32 @@ #include "ime.h" +#include "binary-ops.h" +#include "common.h" #include "ggml-backend-impl.h" #include "ggml-common.h" #include "ggml-cpu.h" +#include "ime_env.h" #include "ime_kernels.h" +#include "ops.h" +#include "repack.h" +#include "rvv_kernels.h" +#include "spine_mem_pool.h" #include "traits.h" +#include "vec.h" + +#include +#include +#include #include +#include #include +#include #include #include // for GGML_ASSERT #include #include - // clang-format off #if defined(__riscv) @@ -25,13 +38,17 @@ #include #endif -#if !defined(__riscv_zfh) -#error "riscv zfh extension not enabled" +#if !defined(__riscv_zfh) || !defined(__riscv_zvfh) +#error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1" #endif -#if defined(RISCV64_SPACEMIT_IME1) +#if !defined(__riscv_zba) +#error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2) #else -#error "RISCV64_SPACEMIT_IME1 not defined" +#error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined" #endif #else @@ -46,382 +63,490 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif -#if defined(RISCV64_SPACEMIT_IME1) -#define QGEMM_STRIDEN_THREAD_ALIGN 16 -#else -#define QGEMM_STRIDEN_THREAD_ALIGN 32 -#endif - // clang-format on -struct qnbitgemm_spacemit_ime_args { - const float * a_ptr = nullptr; - size_t lda = 0; - const std::byte * packed_quant_b_data = nullptr; - const float * quant_b_scale = nullptr; - const void * quant_b_zp = nullptr; - const float * quant_b_blksum = nullptr; - const float * bias = nullptr; - float * c_ptr = nullptr; - size_t ldc = 0; -}; - -constexpr size_t div_round_up(size_t up, size_t down) { - return (up + down - 1) / down; -} - -constexpr size_t q8_blk_size(size_t blk_len) { - const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(blk_size % alignof(float) == 0); - return blk_size; +extern "C" { +extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value); +extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); } namespace ggml::cpu::riscv64_spacemit { -const int num_ai_cores = std::thread::hardware_concurrency() / 2; - -} // namespace ggml::cpu::riscv64_spacemit +struct TLSContext { + int cpu_id{ -1 }; + cpu_set_t cpuset; + void * tcm_buffer{ nullptr }; + size_t tcm_buffer_size{ 0 }; +}; -static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, - const size_t gemm_k, - const qnbitgemm_spacemit_ime_args * gemm_args, - void * const per_gemm_ws, - const size_t m_start, - const size_t m_count, - const size_t n_start, - const size_t n_count) { - constexpr size_t scale_stride = sizeof(uint16_t); - constexpr size_t blk_bitwidth = 4; +thread_local TLSContext tls_context; + +template constexpr size_t get_repacked_block_type_size() { + if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(block_q8_0); + } else if constexpr (std::is_same_v) { + return sizeof(block_q4_0) * INTER_SIZE / QK4_0; + } else if constexpr (std::is_same_v || std::is_same_v) { + return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1; + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q2_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q3_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_mxfp4<1>); + } else if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_1<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_0<1>); + } else { + assert(false); + return 0; + } +} - const size_t k_blks = div_round_up(gemm_k, blk_len); +template constexpr bool block_type_has_zp() { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return false; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return true; + } else { + assert(false); + return false; + } +} - const size_t lda = k_blks * q8_blk_size(blk_len); - const size_t ldc = gemm_args->ldc; - const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); - const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0; +}; - const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; - const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); - const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; +template class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + size = GGML_PAD(size, sizeof(int64_t)); - size_t count_n = 0; - const size_t compute_block_count_n = m_count == 1 ? n_count : 16; - for (size_t n = 0; n < n_count; n += count_n) { - count_n = std::min(n_count - n, compute_block_count_n); + return true; + } + case GGML_OP_MUL_MAT_ID: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - const std::byte * a_row = quant_a_ptr; - const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; - const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; - float * c_blk = c_ptr + n; + size = GGML_PAD(size, sizeof(int64_t)); - int32_t rows_remaining = m_count; + const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert + const int64_t ne12 = op->src[1]->ne[2]; // n_tokens - while (rows_remaining > 0) { - const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( - blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, - scale_stride); + const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t); - c_blk += rows_handled * ldc; - a_row += rows_handled * lda; + size = GGML_PAD(size, sizeof(int64_t)); - rows_remaining -= rows_handled; + return true; + } + default: + // GGML_ABORT("fatal error"); + break; } + return false; } -} -template constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT"); + return false; + } + break; + case GGML_OP_MUL_MAT_ID: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID"); + return false; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return -1; -} -template struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks -}; + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; -template struct block_with_zp { - ggml_half d[N]; // deltas for N qK_1 blocks - uint8_t zp[N]; // zero points for N qK_1 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks -}; + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; -// control size -static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); -static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), - "wrong block_with_zp<4,16> size/padding"); -static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + GGML_TENSOR_BINARY_OP_LOCALS -using block_q4_0x16 = block<4, 16>; -using block_q4_1x16 = block_with_zp<4, 16>; -using block_q8_0x16 = block<8, 16>; + int ith = params->ith; + int nth = params->nth; -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + [[maybe_unused]] const enum ggml_type type = src0->type; - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + const int64_t gemm_m = ne11 * ne12 * ne13; + const int64_t gemm_k = ne10; + const int64_t gemm_n = ne01; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::quantize_a_row_def quantize_a_4row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + bool set_kernel_impl = false; + + int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + set_kernel_impl = true; + } } - } +#endif - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); } - } - return out; -} + const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len); + const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len); -static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { - block_q4_1x16 out; - GGML_ASSERT(QK4_1 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); - float mid = -std::nearbyintf(m / d); - mid = std::min(15.0f, std::max(0.0f, mid)); - out.d[i] = GGML_FP32_TO_FP16(d); - out.zp[i] = static_cast(mid); - } + const int64_t row_stride_a = a_k_blks * block_stride_a; + const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t)); - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + if (ith == 0 && params->wsize < gemm_workspace_size) { + GGML_ABORT("wsize less than gemm_workspace_size"); } - } - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); - } - } + uintptr_t ws_ptr = reinterpret_cast(params->wdata); - return out; -} + void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 16); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - constexpr int nrows_interleaved = 16; + constexpr int64_t row_align = 4; + const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align); - block_q4_0x16 * dst = (block_q4_0x16 *) t->data; - const block_q4_0 * src = (const block_q4_0 *) data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; + const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const int64_t per_mb_rows_wsize = row_align * row_stride_a; + const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + const int64_t barrier_idx = static_cast(ith / 2); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { - return -1; - } + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + if (gemm_m == 1) { + int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth); + int a_blk_start = ith * task_per_thread; + int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks); + if (a_blk_start < a_blk_end) { + quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len, + quant_a_buffer + a_blk_start * block_stride_a); + } + } else { + int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth); + int m_row_blk_start = ith * task_per_thread; + int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks); + for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) { + int m_idx = m_row_blk * row_align; + int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx); + + if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); } - src += nrows_interleaved * nblocks; - } - return 0; - GGML_UNUSED(data_size); -} + ggml_barrier(params->threadpool); -static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_1); - GGML_ASSERT(interleave_block == 16); + const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16; + const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth); - constexpr int nrows_interleaved = 16; + int64_t gemm_n_stride = gemm_n; + if (max_gemm_n_stride < gemm_n) { + gemm_n_stride = + std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_1 * src = (const block_q4_1 *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_1; + if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) { + for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) { + uint8_t * b_col = reinterpret_cast(w_data); + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + int64_t m_row_real = std::min(gemm_m - m_start, row_align); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { - return -1; - } + spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a, + m_row_real * row_stride_a); - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS); + + uint8_t * a_row_ptr = (uint8_t *) tcm_buffer; + float * c_blk = output + m_start * gemm_n + ni; + + int32_t rows_remaining = m_row_real; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining, + n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_ptr += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; + } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) { + uint8_t * a_row = quant_a_buffer; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + gemm_workspace_size; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_UNUSED(data_size); -} + int64_t ni = ith * NB_COLS; + int64_t nb_real = std::min(gemm_n - ni, NB_COLS); -static inline void get_scale_min_k4(int j, - const uint8_t * GGML_RESTRICT q, - uint8_t * GGML_RESTRICT d, - uint8_t * GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; - } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); - } -} + if (ith % 2 == 0 && nb_real > 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + } -static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 16); - GGML_ASSERT(QK_K / QK4_1 == 8); + spine_barrier_wait(cur_barrier); - constexpr int nrows_interleaved = 16; + if (ith % 2 != 0 && nb_real > 0) { + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_K * src = (const block_q4_K *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + for (; ni < gemm_n; ni += NB_COLS * nth) { + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { - return -1; - } + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int j = 0; j < 8; j++) { - for (int i = 0; i < nrows_interleaved; i++) { - uint8_t sc, m; - const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - const float min = - GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); - get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); - // src -> [b0, b32] [b1, b33] ... [b31, b63] - // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] - const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; - if (j % 2 == 0) { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); - } - } else { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); - } - } + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS * nth; + if (next_ni < gemm_n) { + nb_real = std::min(gemm_n - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + next_ni * row_stride_b, + nb_real * row_stride_b); } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); } - } - src += nrows_interleaved * nblocks; - } - return 0; + } else { + const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride); - GGML_UNUSED(data_size); -} + int64_t task_count = task_count_m * task_count_n; + int64_t task_per_thread = (task_count + nth - 1) / nth; + int64_t start = ith * task_per_thread; + int64_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int64_t compute_idx = start; compute_idx < end; compute_idx++) { + const auto tid_n = compute_idx / task_count_m; + const auto tid_m = compute_idx % task_count_m; -namespace ggml::cpu::riscv64_spacemit { + const int64_t m_start = tid_m * gemm_m_stride; + const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride); -template -int repack(struct ggml_tensor *, const void *, size_t); + const int64_t n_start = tid_n * gemm_n_stride; + const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride); -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); -} + const int64_t n_blk = m_count == 1 ? n_count : NB_COLS; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); -} + uint8_t * b_col = reinterpret_cast(w_data) + n_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); -} + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(n_count - ni, n_blk); -class tensor_traits_base : public ggml::cpu::tensor_traits { - public: - virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; -}; + uint8_t * a_row = quant_a_buffer + m_start * row_stride_a; -template class tensor_traits : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; - size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } + float * c_blk = output + m_start * gemm_n + n_start + ni; - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_Q4_0 || // - op->src[0]->type == GGML_TYPE_Q4_1 || // - op->src[0]->type == GGML_TYPE_Q4_K) { - forward_mul_mat_q4(params, op); - return true; + int64_t rows_remaining = m_count; + + uint8_t * b_col_cur = b_col; + uint8_t * b_col_zp_cur = b_col_zp; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk, + rows_remaining, n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } } - default: - // GGML_ABORT("fatal error"); - break; + } } - return false; } - void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; + const ggml_tensor * src0 = op->src[0]; const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; ggml_tensor * dst = op; GGML_TENSOR_BINARY_OP_LOCALS @@ -429,133 +554,381 @@ template class tensor_ int ith = params->ith; int nth = params->nth; - [[maybe_unused]] const enum ggml_type type = src0->type; + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2; + bool set_kernel_impl = false; + size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5; + set_kernel_impl = true; + } + } +#endif - void * w_data = (void *) src0->data; - const float * feature = (const float *) src1->data; - float * output = (float *) dst->data; +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); + } - const size_t batch_feature = ne12 * ne13; - [[maybe_unused]] const size_t batch_weight = ne02 * ne03; - const size_t gemm_m = ne11; - const size_t gemm_k = ne10; - const size_t gemm_n = ne01; + const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len); + const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len); - GGML_ASSERT(batch_weight == 1); + const size_t nbw1 = a_k_blks * block_stride_a; + const size_t nbw2 = ne11 * nbw1; + const size_t nbw3 = nbw2 * ne12; + const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t)); - const size_t block_count_k = div_round_up(gemm_k, QK4_0); - const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); - const size_t per_gemm_workspace_stride = - div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); - const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; - const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - if (ith == 0 && params->wsize < desired_wsize) { - throw std::runtime_error("wsize less than desired_wsize"); + if (ne11 == 1) { + for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) { + int64_t i12 = ii / a_k_blks; + int64_t ak_blk_id = ii % a_k_blks; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len, + a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a); + } + } else { + for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) { + int64_t i12 = ii / ne11; + int64_t i11 = ii % ne11; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10, + quant_a_buffer + i12 * nbw2 + i11 * nbw1); + } } - std::vector qnbitgemm_args(batch_feature); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)] - for (size_t i = 0; i < batch_feature; i++) { - qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; - qnbitgemm_args[i].lda = gemm_k; - qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; - qnbitgemm_args[i].quant_b_scale = nullptr; + int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size); + int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as); + int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1); + int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1); + mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as); - if constexpr (std::is_same_v) { - qnbitgemm_args[i].quant_b_zp = nullptr; - } else { - qnbitgemm_args[i].quant_b_zp = w_data; + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } } - qnbitgemm_args[i].bias = nullptr; - qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; - qnbitgemm_args[i].ldc = gemm_n; + int32_t valid_ep_count_t = 0; + int32_t valid_act_count_t = 0; + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + valid_matrix_row_counts[valid_ep_count_t] = cur_a; + valid_act_count_t += cne1; + valid_ep_count_t += 1; + } + valid_ep_count[0] = valid_ep_count_t; + valid_act_count[0] = valid_act_count_t; } - const uintptr_t ws_ptr = reinterpret_cast(params->wdata); - void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); - const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + const int64_t barrier_idx = static_cast(ith / 2); - { - constexpr size_t block_size_m = 4; - size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); - int32_t task_count = batch_feature * per_gemm_block_count_m; - int32_t task_per_thread = (task_count + nth - 1) / nth; - int32_t start = ith * task_per_thread; - int32_t end = std::min((ith + 1) * task_per_thread, task_count); - for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / per_gemm_block_count_m; - int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; - int32_t m_idx = block_idx_in_gemm * block_size_m; - const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; - int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); - - if (rows_tobe_handled == block_size_m) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = - static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - } else { - while (rows_tobe_handled) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = static_cast(ws) + - gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - rows_tobe_handled -= 1; - m_idx += 1; + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; + + ggml_barrier(params->threadpool); + + const size_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const size_t expert_b_stride = ne01 * row_stride_b; + const size_t per_nb_cols_wsize = NB_COLS * row_stride_b; + + std::array src_workspaces; + std::array dst_workspaces; + + auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; + + const auto valid_ep_count_t = valid_ep_count[0]; + const auto valid_act_count_t = valid_act_count[0]; + + int nth_es = 1; + int nth_n = nth; + + int ith_es = ith % nth_es; + int ith_n = (ith / nth_es) % nth_n; + + if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as && + valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) { + for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride; + + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0); + const int id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + const int64_t i1 = id; + const int64_t i2 = i12; + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)); + + uint8_t * a_row = src1_col; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + nbw1; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; + + if (ith % 2 == 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + } + + spine_barrier_wait(cur_barrier); + + if (ith % 2 != 0) { + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + } + + int64_t nb_real = std::min(ne01, NB_COLS); + for (int64_t ni = 0; ni < ne01; ni += NB_COLS) { + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } + + gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01); + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS; + if (next_ni < ne01) { + nb_real = std::min(ne01 - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d( + b_col, reinterpret_cast(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize); } } } - } + } else { + for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + const int64_t cne1 = matrix_row_counts[cur_a]; - ggml_barrier(params->threadpool); + int64_t src1_cur_start = 0; + int64_t src1_cur_end = cne1; - if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { - return; - } - nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); - - size_t threads_per_gemm = nth / batch_feature; - constexpr size_t gemm_m_stride = 128; - size_t nc = gemm_n; - const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); - const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); - if (max_nc < nc) { - nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); - } - const size_t gemm_n_stride = nc; - const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); - const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); - threads_per_gemm = thread_count_m * thread_count_n; + int64_t src0_cur_start = (ith_n * ne01) / nth_n; + int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01); - { - int task_count = batch_feature * threads_per_gemm; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - const auto gemm_i = compute_idx / threads_per_gemm; - const auto blk_i = compute_idx % threads_per_gemm; - const auto * data = &qnbitgemm_args[gemm_i]; + if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) { + continue; + } + + src0_cur_start = + (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = + (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + + size_t extra_tcm_buffer_size = tcm_buffer_size; + void * extra_tcm_buffer = tcm_buffer; + if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 && + (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) { + spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur, + (src0_cur_end - src0_cur_start) * row_stride_b); + src0_cur = reinterpret_cast(tcm_buffer); + b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b; + extra_tcm_buffer = reinterpret_cast(reinterpret_cast(tcm_buffer) + + (src0_cur_end - src0_cur_start) * row_stride_b); + } - const auto tid_n = blk_i / thread_count_m; - const auto tid_m = blk_i % thread_count_m; + int ir1 = src1_cur_start; - const size_t m_start = tid_m * gemm_m_stride; - const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); + if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) { + int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1; + do { + quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1); - const size_t n_start = tid_n * gemm_n_stride; - const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + uint8_t * quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); - void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; + int iir1 = ir1; + for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1); - sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1); + quant_a_tile_buffer = quant_a_tile_buffer + nbw1; + } + + quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); + iir1 = ir1; + + if (moe_gemm_kernel_m2 != nullptr) { + for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1); + + src_workspaces[0] = quant_a_tile_buffer; + src_workspaces[1] = quant_a_tile_buffer + nbw1; + + dst_workspaces[0] = + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start; + dst_workspaces[1] = (float *) ((char *) dst->data + + ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) + + src0_cur_start; + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, + ne01); + } + } + + for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + + gemm_kernel( + b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start, + 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + + ir1 += quant_a_tile_size; + } while (ir1 < src1_cur_end); + } else { + if (moe_gemm_kernel_m2 != nullptr) { + for (; ir1 < src1_cur_end - 1; ir1 += 2) { + for (int iir1 = 0; iir1 < 2; ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + dst_workspaces[iir1] = + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start; + } + + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } + + for (; ir1 < src1_cur_end; ir1++) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1, + src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } } } +#undef MMID_MATRIX_ROW } - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + int repack(ggml_tensor * t, const void * data, size_t data_size) override { GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), (int) NB_COLS, (int) INTER_SIZE); return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); @@ -563,309 +936,464 @@ template class tensor_ }; class tensor_traits_common : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override { switch (op->op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - size = 0; + case GGML_OP_FLASH_ATTN_EXT: + { + const int n_tasks = n_threads; + const int64_t neq2 = op->src[0]->ne[2]; // number of query heads + const int64_t DK = op->src[1]->ne[0]; + const int64_t DV = op->src[2]->ne[0]; // DV + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float) * + (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV + + GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) * + n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV)); + + size = MAX(prefill, decode); + } return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { switch (op->op) { case GGML_OP_NORM: - forward_norm_f32(params, op); - return true; + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } case GGML_OP_RMS_NORM: - forward_rms_norm_f32(params, op); + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_rms_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } + case GGML_OP_ADD: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_add(params, op); + return true; + } + case GGML_OP_SUB: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_sub(params, op); + return true; + } + case GGML_OP_MUL: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_mul(params, op); + return true; + } + case GGML_OP_DIV: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_div(params, op); + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + forward_flash_attn_ext_f16(params, op); + return true; + case GGML_OP_CONT: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] && + op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) { + spacemit_kernels::rvv::forward_cont_with_permute(params, op); + } else { + ggml_compute_forward_cont(params, op); + } + return true; + } + case GGML_OP_CPY: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] && + ggml_nelements(src0) == ggml_nelements(op)) { + spacemit_kernels::rvv::forward_cpy_with_permute(params, op); + } else { + ggml_compute_forward_cpy(params, op); + } + return true; + } + case GGML_OP_REPEAT: + { + const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op); + const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0]; + + if (rows_equal && broadcast_or_equal) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + default: + break; + } + } + + if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_repeat(params, op); + } + return true; + case GGML_OP_SUM_ROWS: + { + if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) { + spacemit_kernels::rvv::forward_sum_rows(params, op); + } else { + ggml_compute_forward_sum_rows(params, op); + } + } + return true; + case GGML_OP_GET_ROWS: + { + if (op->src[0]->type == op->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_get_rows(params, op); + } return true; + case GGML_OP_CONCAT: + { + const int32_t dim = ggml_get_op_params_i32(op, 0); + if (dim == 0 && op->type == op->src[0]->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_concat(params, op); + } + return true; + // TODO For GGML_OP_GATED_DELTA_NET + // case GGML_OP_GATED_DELTA_NET: + // return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + + const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT); + const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16); + const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128); + const bool supported_vlen = (__riscv_vlenb() == 128); + + if (!(supported_prec && supported_types && supported_shape && supported_vlen)) { + ggml_compute_forward_flash_attn_ext(params, dst); + return; + } + + // total rows in q + const int64_t nr = neq1 * neq2 * neq3; + // rows per thread const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ); - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); + // 4x chunks per thread + // int nth_scaled = nth * 4; + // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + // int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - GGML_ASSERT(epsilon > 0.0f); + // if (nth == 1 || nchunk < nth) { + // nchunk = nth; + // } - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; + int64_t nchunk = nth; - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, nth); + } - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); + ggml_barrier(params->threadpool); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - size_t gvl = __riscv_vsetvlmax_e32m4(); - vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + // The number of elements in each chunk + const int64_t dr = (nr + nchunk - 1) / nchunk; - sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); - p_input += gvl; - p_temp_output += gvl; - length -= gvl; + if (use_tiled) { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); + } else { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); } - gvl = __riscv_vsetvlmax_e32m1(); - - float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - vfloat32m1_t mean_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); - mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); - mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); - mean /= hidden_size; - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - mean_square = sqrt(mean_square - mean * mean + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } } - void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - - size_t gvl = __riscv_vsetvlmax_e32m4(); - // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + int repack(ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); +// Impl By IME1 +static const tensor_traits q4_0_16x32_q8_0; +static const tensor_traits q4_1_16x32_q8_0; +static const tensor_traits q4_k_16x32_q8_0; +// Impl By IME2 +static const tensor_traits q2_k_32x256_q8_0; +static const tensor_traits q3_k_32x256_q8_0; +static const tensor_traits q4_0_32x32_q8_0; +static const tensor_traits q4_1_32x32_q8_0; +static const tensor_traits q4_0_32x256_q8_0; +static const tensor_traits q4_1_32x256_q8_0; +static const tensor_traits q4_k_32x32_q8_0; +static const tensor_traits q6_k_32x32_q8_0; +static const tensor_traits q8_0_32x32_q8_0; +static const tensor_traits mxfp4_32x32_q8_0; +static const tensor_traits q5_k_32x32_q8_0; +static const tensor_traits q5_1_32x32_q8_0; +static const tensor_traits q5_0_32x32_q8_0; +// Impl By RVV +static const tensor_traits_common rvv_impl; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); +} // namespace ggml::cpu::riscv64_spacemit - p_input += gvl; - p_temp_output += gvl; - length -= gvl; +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) { + switch (cur->type) { + case GGML_TYPE_Q2_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0; + } +#endif } + break; + case GGML_TYPE_Q3_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0; + } - gvl = __riscv_vsetvlmax_e32m1(); - - // float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0; + } +#endif - mean_square = sqrt(mean_square + epsilon); +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0; + // } + + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0; + } +#endif - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0; + } +#endif - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0; } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q6_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0; } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q8_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0; } +#endif } - } - } - - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { - memcpy(t->data, data, data_size); - return 0; - } -}; - -static const tensor_traits q4_0_16x8_q8_0; -static const tensor_traits q4_1_16x8_q8_0; -static const tensor_traits q4_k_16x8_q8_0; -static const tensor_traits_common rvv_impl; - -} // namespace ggml::cpu::riscv64_spacemit - -static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_1) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_F32) { - return &ggml::cpu::riscv64_spacemit::rvv_impl; + break; + case GGML_TYPE_MXFP4: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0; + // } +#endif + } + break; + case GGML_TYPE_Q5_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0; + } +#endif + } + break; + default: + break; } return nullptr; } static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor) { + ggml_tensor * tensor) { tensor->extra = (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); @@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba return GGML_STATUS_SUCCESS; } +static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + if (base == nullptr) { + return; + } + + ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base); +} + +static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + return base; +} + +static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + GGML_ASSERT(tensor); + memset((char *) tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + memset(base, value, buffer->size); +} + static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor, + ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_ GGML_UNUSED(buffer); } +static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = { + /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer, + /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base, + /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear, + /* .reset = */ nullptr, +}; + static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "CPU_RISCV64_SPACEMIT"; @@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_ static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { + void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64); + if (base == nullptr) { return nullptr; } - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; + return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size); } static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b GGML_UNUSED(buft); } -static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, - const struct ggml_tensor * tensor) { +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { if (tensor->ne[i] <= 0) { return 0; } } - size_t nbytes; + GGML_UNUSED(buft); + + const auto plain_nbytes = [&]() { + size_t total = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * tensor->nb[i]; + } + return total; + }; + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + return plain_nbytes(); + } + + const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + + const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size; } - } else { - nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; - if (tensor->type == GGML_TYPE_Q4_K) { - GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); - nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - } - } else { - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } + return total; + }; + + const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) { + GGML_ASSERT(row_nbytes % src_block_size == 0); + + size_t total = + add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size); + + if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) { + total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size; } + + return total; + }; + + size_t nbytes = row_nbytes; + switch (tensor->type) { + case GGML_TYPE_Q4_K: + nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8); + break; + case GGML_TYPE_Q6_K: + nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32); + break; + case GGML_TYPE_Q8_0: + nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32); + break; + case GGML_TYPE_Q2_K: + nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>)); + break; + case GGML_TYPE_Q3_K: + nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>)); + break; + case GGML_TYPE_MXFP4: + nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>)); + break; + case GGML_TYPE_Q5_K: + nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8); + break; + case GGML_TYPE_Q5_1: + nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>)); + break; + case GGML_TYPE_Q5_0: + nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>)); + break; + default: + nbytes = add_strided_nbytes(row_nbytes, 1, 1); + break; } - GGML_UNUSED(buft); return nbytes; } namespace ggml::cpu::riscv64_spacemit { class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && @@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } } break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - if (op->src[0]->type == GGML_TYPE_F32) { - return true; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } } break; default: @@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_REPEAT: + case GGML_OP_SUM_ROWS: + case GGML_OP_GET_ROWS: + case GGML_OP_CONCAT: + // case GGML_OP_GATED_DELTA_NET: return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); default: // GGML_ABORT("fatal error"); @@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } // namespace ggml::cpu::riscv64_spacemit ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { /* .iface = */ { /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, @@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { return &ggml_backend_cpu_buffer_type_riscv64_spacemit; } + +extern "C" { +static int bind_ai_thread() { + int fd, bytes; + char str[32]; + + fd = open("/proc/set_ai_thread", O_WRONLY); + if (fd < 0) { + GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n"); + return -1; + } + + snprintf(str, 16, "%d", 0); + bytes = write(fd, str, strlen(str)); + if (bytes < 0) { + GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n"); + close(fd); + return -1; + } + + close(fd); + return 0; +} + +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) { + int cpu_id = sched_getcpu(); + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 && + !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) { + GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid()); + bind_ai_thread(); + } + + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm && + ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) { + CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + pthread_t main_thread = pthread_self(); + const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids; + if (thread_n < 0 || static_cast(thread_n) >= perfer_core_ids.size()) { + GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size()); + } + auto perfer_cpu_id = perfer_core_ids[static_cast(thread_n)]; + CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + int s = + pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + if (s != 0) { + GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id); + } + + int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset; + ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id; + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id); + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size = + ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size; + } + + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + void * rt = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt == nullptr) { + GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) { + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release( + ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt != 0) { + GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index 800d91acdae..6849dd95e05 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -8,6 +8,14 @@ extern "C" { ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n); + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n); + +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment); + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index cbbb6cd9160..6acc6819dfb 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,8 +1,26 @@ +#include "ggml-impl.h" #include "ggml.h" #include "ime_kernels.h" +#include "rvv_kernels.h" #include #include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +# error "RISCV64_SPACEMIT_IME1 not defined" +#endif // clang-format off #if defined(__GNUC__) @@ -11,7 +29,7 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif // clang-format on -namespace sqnbitgemm_spacemit_ime { +namespace spacemit_kernels { #define QUANTIZEM4ROW_KERNEL \ "vmv.s.x v16, zero \n\t" \ @@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime { "vse8.v v31, (s1) \n\t" namespace ime1 { -void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + uint8_t * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, %[BlkLen], e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "sub t2, t2, t0 \n\t" - "slli t1, t0, 2 \n\t" - "add %[SRC], %[SRC], t1 \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE - - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL - - "addi t3, %[BlkLen], 0 \n\t" - "addi s2, s1, 0 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "SET_ZERO%=: \n\t" - "vse8.v v8, (s2) \n\t" - "addi s2, s2, 32 \n\t" - "addi t3, t3, -8 \n\t" - "bnez t3, SET_ZERO%= \n\t" - - QUANTIZEM4ROW_STORE - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 128) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "addi t2, t2, -128 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v16, v24, v16 \n\t" - "vfredmax.vs v24, v16, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e64, m4 \n\t" - "vsse64.v v16, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "sub t2, t2, t2 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "jal x0, QUANTIZE%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 256) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "addi t2, t2, -256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v24, v16 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t1, t2, 0 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - - "TAIL_LOOP%=: \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t2, e32, m1 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 32 \n\t" - "vfmul.vf v1, v0, f11 \n\t" - "vfcvt.x.f.v v2, v1 \n\t" - "vsetvli t0, zero, e16, mf2 \n\t" - "vnclip.wx v3, v2, zero \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vnclip.wx v3, v3, zero \n\t" - "vse8.v v3, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bnez t2, TAIL_LOOP%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK), + [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); } } -void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { const float * SRC = A; - std::byte * DST = QuantA; + uint8_t * DST = QuantA; constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; - if (CountK <= BlkLen) { - float max_abs_A = 0.0f; - for (size_t k = 0; k < CountK; k++) { - max_abs_A = std::max(max_abs_A, fabsf(A[k])); - } - float scale_A = max_abs_A * range_max_reciprocal; - - ((float *) QuantA)[0] = scale_A; - - auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); - - for (size_t k = 0; k < CountK; k++) { - QuantAData_offset[k] = - (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits::lowest(), - (float) std::numeric_limits::max()); - } - for (size_t k = CountK; k < BlkLen; k++) { - QuantAData_offset[k] = 0; - } - - return; - } - - if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { - __asm__ volatile( - "vsetvli t0, zero, e8, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "LOOP%=: \n\t" - "vsetvli t0, %[CNT], e8, m8 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "sub %[CNT], %[CNT], t0 \n\t" - "bnez %[CNT], LOOP%= \n\t" - : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) - : - : "cc", "t0"); - } - if (BlkLen == 16) { - float buffer[64] = { 0.0f }; - __asm__ volatile( - "addi t3, zero, 16*8 \n\t" - "addi t2, zero, 16 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m2 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v2, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v4, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v6, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v10, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v12, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v14, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v18, v2 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v22, v6 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v26, v10 \n\t" - "vfabs.v v28, v12 \n\t" - "vfabs.v v30, v14 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v18, v18, v19 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v22, v22, v23 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v26, v26, v27 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vse32.v v16, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v18, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v20, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v22, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v24, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v26, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v28, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v30, (a1) \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f11, f3, f7 \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f11, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f12, f3, f7 \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fsw f12, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f13, f3, f7 \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f13, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f14, f3, f7 \n\t" - "fmul.s f14, f14, %[RMAXREC] \n\t" - "fsw f14, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f14, %[FONE], f14 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f15, f3, f7 \n\t" - "fmul.s f15, f15, %[RMAXREC] \n\t" - "fsw f15, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f15, %[FONE], f15 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f16, f3, f7 \n\t" - "fmul.s f16, f16, %[RMAXREC] \n\t" - "fsw f16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f16, %[FONE], f16 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f17, f3, f7 \n\t" - "fmul.s f17, f17, %[RMAXREC] \n\t" - "fsw f17, (%[DST]) \n\t" - "addi %[DST], %[DST], -136 \n\t" - "fdiv.s f17, %[FONE], f17 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v18, v2, f11 \n\t" - "vfmul.vf v20, v4, f12 \n\t" - "vfmul.vf v22, v6, f13 \n\t" - "vfmul.vf v24, v8, f14 \n\t" - "vfmul.vf v26, v10, f15 \n\t" - "vfmul.vf v28, v12, f16 \n\t" - "vfmul.vf v30, v14, f17 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v18, v18 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v22, v22 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v26, v26 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vfcvt.x.f.v v30, v30 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v18, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v20, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v22, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v26, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v28, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v30, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vse32.v v16, (%[BUFFER]) \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m2 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) - : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", - "f13", "f14", "f15", "f16", "f17"); - } else if (BlkLen == 32) { - __asm__ volatile( - "addi t3, zero, 32*4 \n\t" - "addi t2, zero, 32 \n\t" - - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 128 \n\t" - "addi a3, %[SRC], 256 \n\t" - "addi a4, %[SRC], 384 \n\t" - - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 36 \n\t" - "addi s3, %[DST], 72 \n\t" - "addi s4, %[DST], 108 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v4, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vle32.v v8, (a3) \n\t" - "addi a3, a3, 512 \n\t" - "vle32.v v12, (a4) \n\t" - "addi a4, a4, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v28, v12 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v20, v20, v22 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vfmax.vv v28, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v21, v20, v21 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfredmax.vs v29, v28, v29 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v21 \n\t" - "vfmv.f.s f12, v25 \n\t" - "vfmv.f.s f13, v29 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fsw f12, (s3) \n\t" - "addi s3, s3, 4 \n\t" - "fsw f13, (s4) \n\t" - "addi s4, s4, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v20, v4, f11 \n\t" - "vfmul.vf v24, v8, f12 \n\t" - "vfmul.vf v28, v12, f13 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vsetvli t0, t1, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 140 \n\t" - "vse8.v v20, (s2) \n\t" - "addi s2, s2, 140 \n\t" - "vse8.v v24, (s3) \n\t" - "addi s3, s3, 140 \n\t" - "vse8.v v28, (s4) \n\t" - "addi s4, s4, 140 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m4 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 128 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); - } else if (BlkLen == 64) { - __asm__ volatile( - "addi t3, zero, 64*2 \n\t" - "addi t2, zero, 64 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 68 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vfmax.vv v24, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v25 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 132 \n\t" - "vse8.v v24, (s2) \n\t" - "addi s2, s2, 132 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 256 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); - } else if (BlkLen == 128) { - __asm__ volatile( - "addi t2, zero, 128 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "sub %[K], %[K], t2 \n\t" - "QUANT%=: \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "vfmax.vv v28, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v30, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vfredmax.vs v31, v30, v31 \n\t" - "vfmv.f.s f10, v31 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v8, (a2) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "jal x0, QUANT%= \n\t" - "END%=: \n\t" - - : [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) - : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); - } else { - float buffer[8] = { 0.0f }; - size_t cnt = BlkLen / 256; - - __asm__ volatile( - "slli t3, %[BLK], 2 \n\t" - "blt %[K], %[BLK], LOOP_TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_CMP%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v16, v16, v24 \n\t" - "vfmax.vv v0, v0, v16 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, LOOP_CMP%= \n\t" - "sub %[SRC], %[SRC], t3 \n\t" - "addi t6, %[CNT], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_QUANT%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "vse8.v v4, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bnez t6, LOOP_QUANT%= \n\t" - "sub %[K], %[K], %[BLK] \n\t" - "bge %[K], %[BLK], LOOP_MAIN%= \n\t" - "blez %[K], END%= \n\t" - "LOOP_TAIL%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[K], 0 \n\t" - "addi s1, %[SRC], 0 \n\t" - "TAIL_CMP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t6, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "sub t6, t6, t0 \n\t" - "vfabs.v v0, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, TAIL_CMP%= \n\t" - "addi t6, %[K], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[K], 0 \n\t" - "TAIL_QUANT%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t1, t6, e32, m8 \n\t" - "vle32.v v0, (s1) \n\t" - "addi s1, s1, 256 \n\t" - "sub t6, t6, t1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 64 \n\t" - "bnez t6, TAIL_QUANT%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), - [CNT] "r"(cnt) - : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); - } + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); } } // namespace ime1 @@ -1451,1746 +584,444 @@ namespace { "vadd.vi v1, v1, -12 \n\t" template -void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } -} -template -void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); - const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; - - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } } template -void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { + GGML_UNUSED(ldc); size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - // zp offset - "addi s7, %[B], 32 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s7, %[B], 32 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} -template -void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - const size_t INNER = BlkLen / 16; - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - // zp offset - "addi s7, %[B], 64 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - - // load scale - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - // load a scale - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - // a scale * b scale - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s7, %[B], 64 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} - -template -inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias, ldc); - - } else if (scalestride == 2) { - SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( - BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); - } -} -template -inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias); - } else if (scalestride == 2) { - SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } } } - } // namespace namespace ime1 { -size_t gemm_kernel_i8i4(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float * Bias, - const size_t ScaleStride) { - GGML_UNUSED(CountM); - GGML_UNUSED(CountK); - GGML_UNUSED(ldc); - if (CountM >= 4) { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { + if (quant_b_zp != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 4; } else { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + if (quant_b_zp != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 1; } } } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp new file mode 100644 index 00000000000..0c7a036a92a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp @@ -0,0 +1,5768 @@ +#include "ggml-impl.h" +#include "ggml.h" +#include "ime_kernels.h" +#include "rvv_kernels.h" +#include "string.h" + +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME2) +#else +# error "RISCV64_SPACEMIT_IME2 not defined" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels { +namespace ime2 { + +template +void gemm_kernel_i8i2k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + uint8_t * zeros16 = (uint8_t *) (quant_b_blk_data->zeros16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + _Float16 * zeros_fp16 = (_Float16 *) zeros16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + uint8_t * scales_temp = scales; + uint8_t * zps_temp = scales; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, zps_temp++) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + int16_t a_sum = a_sum_row[mi * 16 + kii]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0.0; + + uint8_t b_zp = zps_temp[ci * 16] >> 4; + uint8_t b_scale = scales_temp[ci] & 0x0F; + for (size_t bi = 0; bi < 16; bi++) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + uint8_t b0 = b_data_col[ci * 16 + bi]; + acc_0 += static_cast(a0) * static_cast((b0 >> b_shift) & 0x03); + } + + _Float16 scale_item = + static_cast<_Float16>(b_scale) * static_cast<_Float16>(factor_scale) * scales_fp16[ci]; + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + output[ci + mi * NB_COLS] += b_zp * a_sum * a_scale_row[mi] * zeros_fp16[ci]; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i3k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * b_hmask = quant_b_blk_data->hmask; + int8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + int8_t * scales_temp = scales; + uint16_t * b_mask_col = (uint16_t *) b_hmask; + + float acc_0_max = 0.0f; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, b_mask_col += NB_COLS) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0; + // blk 2 * kii + 0 + uint16_t b_shift_mask = 1; + for (size_t bi = 0; bi < 16; bi++, b_shift_mask <<= 1) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + int8_t b0 = static_cast((b_data_col[ci * 16 + bi] >> b_shift) & 0x03); + b0 -= b_mask_col[ci] & b_shift_mask ? 0 : 4; + acc_0 += static_cast(a0) * static_cast(b0); + } + + _Float16 scale_item = static_cast<_Float16>(scales_temp[ci]) * scales_fp16[ci] * + static_cast<_Float16>(factor_scale); + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t kblks_per_blk = 16; + GGML_ASSERT(k_blks % kblks_per_blk == 0); + + int64_t b_blk_stride = (sizeof(_Float16) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(_Float16); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + output_f16[ci + mi * NB_COLS] = static_cast<_Float16>(0.0f); + } + } + + size_t kii = 0; + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + _Float16 * b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + _Float16 a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_scale_fp16[ci]; + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output_f16[ci + mi * NB_COLS] += + static_cast(acc) * static_cast(a_scale) * static_cast(b_scale); + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } else { + kii++; + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_hp_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * MB_ROWS; + const size_t a_subblk_stride = q8_hp_blk_size(32, false, false) * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const uint8_t * b_tile_base = quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride) { + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + const uint8_t * b_superblk_ptr = b_tile_base + ki * b_superblk_stride; + const block_q4_0x32_layout * b_blocks = reinterpret_cast(b_superblk_ptr); + const uint8_t * b_zps = + quant_b_zp ? b_superblk_ptr + sizeof(block_q4_0x32_layout) * k_subblks_per_superblk : nullptr; + + _Float16 * a_sum_row = (_Float16 *) (a_data + a_subblk_stride * k_subblks_per_superblk); + _Float16 * a_scale_avg_row = (_Float16 *) (a_data + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + _Float16 scale_factor = a_scale_avg_row[0]; + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + int8_t * a_subblk = a_data + a_subblk_stride * ksi + MB_ROWS * sizeof(_Float16); + const _Float16 a_scale = a_scale_row[0]; + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + const uint8_t * b_qs = b_block.qs + ci * 16; + _Float16 b_scale = b_block.d[ci] * a_scale; + + int16_t acc = 0; + for (size_t bi = 0; bi < 16; bi++) { + uint8_t b = b_qs[bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + + acc += static_cast(a_subblk[mi * 32 + 2 * bi]) * static_cast(b0) + + static_cast(a_subblk[mi * 32 + 2 * bi + 1]) * static_cast(b1); + } + + const _Float16 scaled_acc = static_cast<_Float16>(acc) * b_scale; + output_f16[ci + mi * NB_COLS] += scaled_acc; + } + } + } + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + const uint8_t * b_zp_row = b_zps ? b_zps + ksi * NB_COLS : nullptr; + const _Float16 a_scale = a_scale_row[0]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + const _Float16 a_sum = a_sum_row[mi * k_subblks_per_superblk + ksi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_block.d[ci] * a_scale; + _Float16 a_sum_bzp = a_sum; + if (b_zp_row) { + a_sum_bzp = a_sum * static_cast<_Float16>(0.125f) * static_cast<_Float16>(b_zp_row[ci]); + } + + const _Float16 scaled_acc = a_sum_bzp * b_scale; + output[ci + mi * NB_COLS] += scaled_acc * scale_factor; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + auto val = static_cast(output_f16[ci + mi * NB_COLS]) * static_cast(scale_factor); + output[ci + mi * NB_COLS] += val; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void moe_gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] = c_ptr[mi]; + } + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = (a_data[mi])[2 * bi]; + int8_t a1 = (a_data[mi])[2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + (c_data[mi])[ci] = output[mi * NB_COLS + ci]; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] += NB_COLS; + } + } +} + +template +void moe_gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + // blk_len is expected to be 32 for Q5 types. + int64_t a_blk_stride = q8_blk_size(blk_len, true); + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + c_data[mi] = c_ptr[mi]; + } + + if (quant_b_zp) { + using blk_type = nrow_block_q5_1; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } else { + using blk_type = nrow_block_q5_0; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } +} + +template +void gemm_kernel_i8i8_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + blk_len); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + int8_t * b_data = (int8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len; bi++) { + int8_t a0 = a_data[mi * blk_len + bi]; + int8_t b0 = b_data[ci * blk_len + bi]; + acc += static_cast(a0) * static_cast(b0); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 for Q5 types + // quant_b_zp != nullptr => nrow_block_q5_1 (has zp) + // quant_b_zp == nullptr => nrow_block_q5_0 (no zp) + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + if (quant_b_zp) { + // nrow_block_q5_1: scales16[NB_COLS] + zp[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_1; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + // qh is packed as 4 bytes per column (32 bits for 32 elements) + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } else { + // nrow_block_q5_0: scales16[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_0; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + // Q5_0 has no zp, use default offset 16 (midpoint of 5-bit unsigned range) + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } +} + +template +void gemm_kernel_i8mxfp4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 (QK_MXFP4) + // quant_b_zp is unused for MXFP4 (symmetric quantization) + GGML_UNUSED(quant_b_zp); + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + using blk_type = nrow_block_mxfp4; + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = GGML_E8M0_TO_FP32_HALF(quant_b_blk_data->e[ci]); + + // Read 32 sign bits for this column + uint32_t sign_bits; + memcpy(&sign_bits, &quant_b_blk_data->qh[ci * 4], 4); + + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + + // qs[ci*16 + bi] stores abs(vals[bi*2]) in low 4 bits + // and abs(vals[bi*2+1]) in high 4 bits + uint8_t qs_byte = quant_b_blk_data->qs[ci * 16 + bi]; + int8_t b_abs0 = static_cast(qs_byte & 0x0F); + int8_t b_abs1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract sign bits: bit (2*bi) for vals[2*bi], bit (2*bi+1) for vals[2*bi+1] + int8_t b0 = (sign_bits >> (2 * bi)) & 1 ? -b_abs0 : b_abs0; + int8_t b1 = (sign_bits >> (2 * bi + 1)) & 1 ? -b_abs1 : b_abs1; + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +void gemm_kernel_i8i2k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "addi %[A], %[A], 4 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 32 \n\t" // A data addr + "addi s3, %[B], 0 \n\t" + "vxor.vv v30, v29, v29 \n\t" // tmp result + + "INNER_K_LOOP%=: \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vxor.vv v2, v2, v2 \n\t" + "vxor.vv v3, v3, v3 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + + // A data, 1x64@i8 + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v4, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v5, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v6, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetvli t0, x0, e64, mf2 \n\t" + "vslideup.vi v3, v4, 2 \n\t" + "vslideup.vi v28, v5, 4 \n\t" + "vslideup.vi v29, v6, 6 \n\t" + + // init the accumu to zero + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v18, v18 \n\t" + "vxor.vv v22, v18, v18 \n\t" + "vxor.vv v24, v18, v18 \n\t" + "vxor.vv v26, v18, v18 \n\t" + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu v20, v2, v8, i8 \n\t" + "vmadotsu v22, v2, v12, i8 \n\t" + "vmadotsu v24, v2, v16, i8 \n\t" + "vmadotsu v26, v2, v4, i8 \n\t" + + "vmadotsu v20, v3, v9, i8 \n\t" + "vmadotsu v22, v3, v13, i8 \n\t" + "vmadotsu v24, v3, v17, i8 \n\t" + "vmadotsu v26, v3, v5, i8 \n\t" + + "vmadotsu v20, v28, v10, i8 \n\t" + "vmadotsu v22, v28, v14, i8 \n\t" + "vmadotsu v24, v28, v18, i8 \n\t" + "vmadotsu v26, v28, v6, i8 \n\t" + + "vmadotsu v20, v29, v11, i8 \n\t" + "vmadotsu v22, v29, v15, i8 \n\t" + "vmadotsu v24, v29, v19, i8 \n\t" + "vmadotsu v26, v29, v7, i8 \n\t" + + "vand.vi v10, v0, 0xf \n\t" // scale + "vwadd.vx v12, v10, x0 \n\t" + "vsetvli t0, x0, e16, m2 \n\t" + "vwadd.vx v16, v12, x0 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v6, v2, v4, 3 \n\t" // 0,1 + "vpack.vv v8, v3, v5, 3 \n\t" // 2,3 + + // mul scale + "vmacc.vv v30, v6, v16 \n\t" + "vmacc.vv v30, v7, v17 \n\t" + "vmacc.vv v30, v8, v18 \n\t" + "vmacc.vv v30, v9, v19 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + // load zp B + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v4, (s3) \n\t" + "vsrl.vi v8, v4, 4 \n\t" // zp + + // asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + + "vsetvli t0, x0, e16, mf4 \n\t" + "vle16.v v2, (%[A]) \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vnsrl.wi v12, v2, 0 \n\t" // low 8 + "vnsra.wi v13, v2, 8 \n\t" // high 8 + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v20, v13, v8, i8 \n\t" + "vmadotsu v22, v13, v9, i8 \n\t" + "vmadotsu v24, v13, v10, i8 \n\t" + "vmadotsu v26, v13, v11, i8 \n\t" + + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vsll.vi v26, v26, 8 \n\t" + + "vmadotu v20, v12, v8, i8 \n\t" + "vmadotu v22, v12, v9, i8 \n\t" + "vmadotu v24, v12, v10, i8 \n\t" + "vmadotu v26, v12, v11, i8 \n\t" + + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v28, v2, v4, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v0, (t2) \n\t" // scale16 + "addi t2, t2, 64 \n\t" + "vle16.v v1, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v2, v0 \n\t" + "vfwcvt.f.f.v v4, v1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v30, v30 \n\t" + "vfcvt.f.x.v v28, v28 \n\t" + "addi %[B], t2, 64 \n\t" + "mv %[A], t3 \n\t" + + "vfmul.vv v30, v30, v2 \n\t" // mul scale16 + "vfmacc.vv v30, v28, v4 \n\t" // + mul zero16 + "vfmacc.vf v31, fa0, v30 \n\t" + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3"); + } +} + +void gemm_kernel_i8i2k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + _Float16 scale = 0.0625f; + _Float16 scale_1 = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v31, v31 \n\t" // init result + "vxor.vv v29, v31, v31 \n\t" + "vxor.vv v30, v31, v31 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 128 \n\t" // A data addr + "addi s4, t2, 1024 \n\t" // scale16 addr + "addi s4, s4, 1024 \n\t" // TODO + "addi s3, %[B], 0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s4) \n\t" // load scale16 + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v22, v1, v1, 3 \n\t" + + "addi s4, t3, 256 \n\t" // addr 1 + "addi s5, t3, 512 \n\t" // addr 2 + "addi s6, t3, 768 \n\t" // addr 3 + + // init the accu to 0 + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v25, v25, v25 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v27, v27, v27 \n\t" + + "INNER_K_LOOP%=: \n\t" + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + "vand.vi v1, v1, 0xf \n\t" + + "vfwcvt.f.x.v v20, v1 \n\t" // f16 scale B + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v0, v20, v22 \n\t" // mul scale16 + "vfmul.vv v1, v21, v22 \n\t" // mul scale16 + "vfmul.vf v0, v0, %[SCALE] \n\t" // mul magic + "vfmul.vf v1, v1, %[SCALE] \n\t" // mul magic + + // A data, 4x64@i8 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 64 \n\t" + "vle8.v v3, (s4) \n\t" + "addi s4, s4, 64 \n\t" + "vle8.v v4, (s5) \n\t" + "addi s5, s5, 64 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 64 \n\t" + + // 4x64 => 4x16x4 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v6, v2, v3, 1 \n\t" + "vpack.vv v8, v4, v5, 1 \n\t" + "vpack.vv v2, v6, v8, 2 \n\t" // 0, 2 + + "vpack.vv v20, v2, v2, 3 \n\t" // 1 + "vor.vv v23, v21, v21 \n\t" + "vpack.vv v20, v3, v3, 3 \n\t" // 3 + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu.hp v24, v2, v8, v0, 0, i8 \n\t" + "vmadotsu.hp v25, v2, v12, v0, 1, i8 \n\t" + "vmadotsu.hp v26, v2, v16, v0, 2, i8 \n\t" + "vmadotsu.hp v27, v2, v4, v0, 3, i8 \n\t" + + "vmadotsu.hp v24, v23, v9, v0, 4, i8 \n\t" + "vmadotsu.hp v25, v23, v13, v0, 5, i8\n\t" + "vmadotsu.hp v26, v23, v17, v0, 6, i8\n\t" + "vmadotsu.hp v27, v23, v5, v0, 7, i8 \n\t" + + "vmadotsu.hp v24, v3, v10, v1, 0, i8 \n\t" + "vmadotsu.hp v25, v3, v14, v1, 1, i8 \n\t" + "vmadotsu.hp v26, v3, v18, v1, 2, i8 \n\t" + "vmadotsu.hp v27, v3, v6, v1, 3, i8 \n\t" + + "vmadotsu.hp v24, v21, v11, v1, 4, i8\n\t" + "vmadotsu.hp v25, v21, v15, v1, 5, i8\n\t" + "vmadotsu.hp v26, v21, v19, v1, 6, i8\n\t" + "vmadotsu.hp v27, v21, v7, v1, 7, i8 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v2, v24, v25, 1 \n\t" + "vpack.vv v4, v26, v27, 1 \n\t" + "vpack.vv v6, v2, v4, 2 \n\t" // 0,1,2,3 + + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + // load zp B, 16x8x4@int4 + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v0, (s3) \n\t" + "vsrl.vi v0, v0, 4 \n\t" // zp + + // 4x16@int16 + "vsetvli t0, x0, e16, m1 \n\t" // a sum + "vle16.v v12, (%[A]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnsrl.wi v10, v12, 0 \n\t" // low 8 + "vnsra.wi v11, v12, 8 \n\t" // high 8 + + // asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v18, v11, v0, i8 \n\t" + "vmadotsu v20, v11, v1, i8 \n\t" + "vmadotsu v22, v11, v2, i8 \n\t" + "vmadotsu v24, v11, v3, i8 \n\t" + "vsll.vi v18, v18, 8 \n\t" + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vmadotu v18, v10, v0, i8 \n\t" + "vmadotu v20, v10, v1, i8 \n\t" + "vmadotu v22, v10, v2, i8 \n\t" + "vmadotu v24, v10, v3, i8 \n\t" + + "vpack.vv v10, v18, v20, 2 \n\t" + "vpack.vv v12, v22, v24, 2 \n\t" + "vpack.vv v14, v10, v12, 3 \n\t" + "vpack.vv v16, v11, v13, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t2, t2, 64 \n\t" + "vle16.v v20, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v22, v20 \n\t" + + // mul 1/magic + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmul.vf v0, v6, %[SCALE_1] \n\t" + "vfwmul.vf v2, v7, %[SCALE_1] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v14, v14 \n\t" + "vfcvt.f.x.v v15, v15 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + + "addi %[B], t2, 64 \n\t" + "mv %[A], s6 \n\t" + + "vfmacc.vv v0, v14, v22 \n\t" // + mul zero16 + "vfmacc.vv v1, v15, v22 \n\t" + "vfmacc.vv v2, v16, v22 \n\t" + "vfmacc.vv v3, v17, v22 \n\t" + + "vfmacc.vf v28, fa0, v0 \n\t" // mul a scale + "vfmacc.vf v29, fa1, v1 \n\t" + "vfmacc.vf v30, fa2, v2 \n\t" + "vfmacc.vf v31, fa3, v3 \n\t" + + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "vse32.v v29, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v30, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v31, (t1) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks), [LDC] "r"(ldc * 4), [SCALE] "f"(scale), [SCALE_1] "f"(scale_1) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i3k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; //only support 32 in ASM + using blk_type = nrow_block_q3_k; + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride; + int64_t b_ncol_block_stride = sizeof(blk_type); + + // Constants used by q3_k scaling in HP branch: + // - k_q3k_scale_step: per-nibble scale factor (1/16). + // - k_a_scale_post_mul: A_scale needs an extra *16 at the end (pairs with 1/16 above). + const _Float16 k_q3k_scale_step = (_Float16) 0.0625f; // 1 / 16 + const float k_a_scale_post_mul = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; +#if 0 + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1 32bit + // Asum int16 * 16 256bit + // A M1K256 int8 2048bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+32 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" // clear acc + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + + // ordinary vmadot: vle*10 vecIns*78 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // K0-15 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load B qs chunk (128B per K16, 16 times => 2048B) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (s3) \n\t" + "addi s3, s3, 64 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K16-31 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // load A scale (fp32) and advance A to next superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // t4 = next B blk base + "addi s3, s2, 4+32 \n\t" + + // load B scales16[32] (fp16) at end of qs region + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v2, (s6) \n\t" + + // pointer modify + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v30 \n\t" + "vfmul.vf v1, v24, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v31, v1, v26 \n\t" + + // next K-superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "s2", "s3", "s4", "s5", "s6", "s7"); +#else + + __asm__ volatile( + // ========================= + // Kernel overview (M1 x N32) + // ========================= + // Process one output row (M=1) and 32 columns (N=32) per call. + // + // Loop structure: + // - Outer loop: K superblocks of size K=256 (k_blks times) + // - Each K256 superblock is broken into 4 x K64 + // - Each K64 is processed as 4 x K16 "sub-blocks" (via unpack+dot) + // + // Data layout (high level): + // A (q8k K=256, per superblock): + // [ fp32 a_scale ][ int16 a_sum[16] ][ int8 a_qs[256] ] + // B (nrow_block_q3_k<32>, per superblock): + // [ int8 scales[32*16] ][ hmask[1024] ][ qs[2048] ][ fp16 scales16[32] ] + // + // Registers/pointers: + // s2: pA (points at A superblock header; used to load fp32 a_scale) + // s3: pA_qs (points at A int8 data within the current superblock) + // s4: pB_scales (points at B int8 per-K16 scales) + // s5: pB_hmask (points at B sign mask area) + // s6: pB_qs (points at B 2-bit packed qs area) + // s8: pB_scales16 (points at B fp16 scales16[32] at the end of block) + // s7: pB_base (base pointer to current B block; used for block-to-block stride) + + // t2 = number of K256 superblocks + "mv t2, %[KBLKS] \n\t" + // t3 = number of K64 chunks per K256 superblock (256 / 64) + "li t3, 4 \n\t" + + // A pointers + "mv s2, %[pA] \n\t" // s2 = pA_superblock (a_scale at +0) + "addi s3, %[pA], 4+32 \n\t" // s3 = pA_qs (skip a_scale + a_sum[16]) + + // B pointers for nrow_block_q3_k<32> + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales[32*16]) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + // scales16 is at the end of the block: qs(2048) after hmask + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" // s8 = pB_scales16 (fp16 scales16[32]) + "mv s7, %[pB] \n\t" // s7 = pB_base (for next-block address calc) + + // v31: final FP32 accumulator for N=32 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ---- Preload B scales16[32] and build FP16 scale vector used by vmadot.hp ---- + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" // load fp16 scales16[32] + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" // broadcast/pack to match lanes + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" // v30 = scales16 * (1/16) + + // v24-v27: fp16 partial accumulators for a K64 chunk (vmadot.hp outputs) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + // HP vmadot: vle*10 vecIns*38 vmadot.hp*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" // loop over K256 superblocks + "K64_LPST%=: \n\t" // loop over 4 x K64 chunks + + // ------------------------------------------------------------ + // K0-15: load B scales + {hmask, qs} + A data; unpack and dot + // ------------------------------------------------------------ + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" // B int8 scales for this K16 + "addi s4, s4, 128 \n\t" + + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" // B hmask for this K16 + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v3, (s3) \n\t" // A int8 data for this K16 + "addi s3, s3, 64 \n\t" + + // Convert B int8 scales to FP16 and apply scales16*(1/16) + "vsetvli t0, x0, e8, m1 \n\t" + "vfwcvt.f.x.v v28, v2 \n\t" // int8 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v1, v28, v30 \n\t" // v1: FP16 scale vector for vmadot.hp + "vfmul.vv v29, v29, v30 \n\t" + + // Unpack B 2-bit qs + hmask -> signed int8 in v12..v15 + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + // (Next K16 unpack path uses a fresh hmask load) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Prepare another group from packed qs (bit shifts) + apply sign from hmask + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vsrl.vi v16, v8, 6 \n\t" + "vsrl.vi v17, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v18, v10, 6 \n\t" + "vsrl.vi v19, v11, 6 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v16, v16, -4, v0.t \n\t" + + // A shift for the second dot within this K64 + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + // Dot products with FP16 scaling (accumulate into v24..v27) + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v12, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v13, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v14, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v15, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v16, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v17, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v18, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v19, v1, 7, i8 \n\t" + + // (K32-47 / K48-63 blocks continue unchanged...) + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vmv.v.v v1, v29 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v3, v3, 4 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + + "vsrl.vi v20, v8, 6 \n\t" + "vsrl.vi v21, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v22, v10, 6 \n\t" + "vsrl.vi v23, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v20, v20, -4, v0.t \n\t" + + // K48-63 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v8, v4, 6 \n\t" + "vsrl.vi v9, v5, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v10, v6, 6 \n\t" + "vsrl.vi v11, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v8, v8, -4, v0.t \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v20, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v21, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v22, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v23, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v8, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v9, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v10, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v11, v1, 7, i8 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ---- End of K64 chunk: reduce fp16 accumulators -> fp32 and scale by A ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v24, v25, 1 \n\t" + "vpack.vv v14, v26, v27, 1 \n\t" + "vpack.vv v16, v12, v14, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v26, v16 \n\t" // fp16 -> fp32 vector (qsum * b_scales) + + // Load A scale and advance A pointer to next K256 superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // next B block base + "addi s3, s2, 4+32 \n\t" // reset A data pointer for next block + + // Advance B pointers to next K256 superblock + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" + "addi s7, t4, 0 \n\t" + "addi t2, t2, -1 \n\t" + + // Final per-block scaling: a_scale * 16.0f + "fmul.s f0, f0, %[a_post_mul] \n\t" + // acc += (qsum * b_scales) * (a_scale*16) + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vf v31, f0, v26 \n\t" + + "beqz t2, BLK_LPND%= \n\t" + + // Preload next block's scales16 and rebuild v30 for vmadot.hp + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" + + // Reset fp16 partial accumulators for next K64 loop(s) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [q3_step] "f"(k_q3k_scale_step), + [a_post_mul] "f"(k_a_scale_post_mul) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"); +#endif + } +} + +void gemm_kernel_i8i3k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q3_k<32>; + constexpr size_t NB_COLS = 32; //only support 32 in ASM + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t b_ncol_block_stride = sizeof(blk_type); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; + + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1* 4row 128bit + // Asum int16 * 16 4row 1024bit + // A M1K256 int8 4row 8192bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 16+128 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // v24-v27: K256 temp accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "vxor.vv v28, v0, v0 \n\t" // v28-v31: final accumulator + "vxor.vv v29, v0, v0 \n\t" + "vxor.vv v30, v0, v0 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ordinary vmadot: vle*13 vecIns*96 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // ========== K0-15: First K16 sub-block ========== + // Load B INT8 scale factors (32 cols × 16 K16 blocks) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v8, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // Load B quantized data (32 cols × 16 elements × 2bit, stored in 4 groups) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // Load B hmask (32 cols × 16bit sign mask) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Load A data (4 rows × 16 elements × INT8) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v12, (s3) \n\t" + "addi s3, s3, 256 \n\t" // Jump to next row + "vle8.v v13, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v14, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v15, (s3) \n\t" + "addi s3, s3, -768+64 \n\t" // Back to first row, advance 16 elements + + // Pack A data: merge 4 rows into 2 vectors + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v13, 1 \n\t" + "vpack.vv v18, v14, v15, 1 \n\t" + "vpack.vv v2, v16, v18, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" // 4 rows × cols 0-7 + "vmadot v18, v2, v13, i8 \n\t" // 4 rows × cols 8-15 + "vmadot v20, v2, v14, i8 \n\t" // 4 rows × cols 16-23 + "vmadot v22, v2, v15, i8 \n\t" // 4 rows × cols 24-31 + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" // Merge cols 0-15 + "vpack.vv v14, v20, v22, 2 \n\t" // Merge cols 16-31 + "vpack.vv v16, v12, v14, 3 \n\t" // Inter-row results (INT16) + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // INT8 → INT16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // INT16 → INT32 + + // Accumulate to K256 accumulator: qsum * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" // Row 0 + "vmacc.vv v25, v17, v23 \n\t" // Row 1 + "vmacc.vv v26, v18, v23 \n\t" // Row 2 + "vmacc.vv v27, v19, v23 \n\t" + + // ========== K16-31, K32-47, K48-63: Similar processing ========== + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 8 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 4 \n\t" + "vsll.vi v13, v5, 4 \n\t" + "vsll.vi v14, v6, 4 \n\t" + "vsll.vi v15, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" + "vmadot v18, v2, v13, i8 \n\t" + "vmadot v20, v2, v14, i8 \n\t" + "vmadot v22, v2, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 2 \n\t" + "vsll.vi v13, v5, 2 \n\t" + "vsll.vi v14, v6, 2 \n\t" + "vsll.vi v15, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v3, v3, 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ========== K256 superblock complete, apply scale factors ========== + // Load A's 4 row scale factors (FP32) + "flw f0, (s2) \n\t" + "flw f1, 4(s2) \n\t" + "flw f2, 8(s2) \n\t" + "flw f3, 12(s2) \n\t" + "add s2, s2, %[A_STR] \n\t" // Advance to next superblock + "add t4, s7, %[B_STR] \n\t" // t4 = next B block address + "addi s3, s2, (4+32)*4 \n\t" + + // Load B FP16 global scale factors (32 cols) + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (s6) \n\t" + + // Update B pointers to next block + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // ========== Type conversion and final scaling ========== + // FP16 → FP32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v9, v8 \n\t" + + // INT32 → FP32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + + // Compute a_scale * b_scale (4 rows) + "vfmul.vf v12, v9, f0 \n\t" + "vfmul.vf v13, v9, f1 \n\t" + "vfmul.vf v14, v9, f2 \n\t" + "vfmul.vf v15, v9, f3 \n\t" + + // Final accumulation: result += qsum * a_scale * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vv v28, v12, v24 \n\t" + "vfmacc.vv v29, v13, v25 \n\t" + "vfmacc.vv v30, v14, v26 \n\t" + "vfmacc.vv v31, v15, v27 \n\t" + + // Prepare for next K superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // Clear K256 accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + + // ========== Store results (4 rows × 32 cols) ========== + "mv t5, %[pC] \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v28, (%[pC]) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v29, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v30, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v31, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [A_STR] "r"(a_nrow_block_stride), [LDC] "r"(ldc * 4) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "f2", "f3", "s2", "s3", "s4", "s5", "s6", "s7"); + } +} + +void gemm_kernel_i8i4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... +#if 0 + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*7 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" // Bzp u8 -> u16 + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vfcvt.f.x.v v16, v26 \n\t" // zp i16 -> fp16 + "vadd.vi v18, v16, 0 \n\t" + "vadd.vi v20, v16, 0 \n\t" + "vadd.vi v22, v16, 0 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + // mac result * b_scale; f16*f16->f32 + "vfwmul.vv v31, v30, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + +#endif + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(uint8_t) + // b zp + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // Bzp uint8_t * N32 256bit; + // || scl*32..(fp16) | zp*32(uint8) | blk0 blk1 ... blk31 || scl*32..(fp16) ... + // || Element || Element ... + + //bias always be nullptr +#if 0 + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*6 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v31, (s4) \n\t" // B zp 32Row*uint8 = 256bit + "addi s4, s4, 32+128*4 \n\t" + + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v31, x0 \n\t" // Bzp u8 -> u16 + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vfwcvt.f.f.v v22, v30 \n\t" // b_scale fp16 -> fp32 + "vfcvt.f.x.v v18, v26 \n\t" // zp i16 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfwadd.vv v20, v18, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // mac result * b_scale; f32*f32->f32 + "vfmul.vv v31, v22, v20 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#endif + } + } +} + +void gemm_kernel_i8i4_hp_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" // init acc to zero + "mv t4, %[BK] \n\t" + "li t0, 0x4c00 \n\t" // 16 in fp16 + "fmv.h.x fa0, t0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + "li t5, 8 \n\t" + "addi t6, %[A], 288 \n\t" // point to blk scale + "flh ft1, (t6) \n\t" + "addi t6, %[A], 272 \n\t" // point to asum + + // init the acc fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v18, v18 \n\t" + "vxor.vv v17, v18, v18 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v18, v18 \n\t" + + "INNER_BLK_LOOP%=: \n\t" + // load a sum and scale + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A]) \n\t" + "addi %[A], %[A], 2 \n\t" + // load A + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A]) \n\t" // 1x32@i8 + "addi %[A], %[A], 32 \n\t" + + // load scale B and B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B]) \n\t" // b_scale fp16 + "addi %[B], %[B], 64 \n\t" + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" // scale b * scale a + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" // scale b * scale a * asm + "vfwmacc.vf v31, ft1, v10 \n\t" // asum * scale a * scale b * blk scale + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" // high 4 + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + "addi t5, t5, -1 \n\t" + "bgtz t5, INNER_BLK_LOOP%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t4, t4, -1 \n\t" + "vfwmacc.vf v31, ft1, v20 \n\t" + //"vsetvli t0, x0, e32, m1 \n\t" + //"vfmul.vf v31, v31, ft1 \n\t" // blk scale + + // update A ptr + "addi %[A], t6, 2 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "ft0", "ft1"); + } + } else { + // TODO: support quant_b_zp for i8i4 hp kernel + GGML_ABORT("gemm_kernel_i8i4_hp_m1 with quant_b_zp is not supported yet"); + } +} + +void gemm_kernel_i8i4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; +#if 0 + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vwmul.vx v10, v8, t1 \n\t" // 8*asum + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vrgather.vi v0, v10, 0 \n\t" + "vrgather.vi v1, v10, 1 \n\t" + "vrgather.vi v2, v10, 2 \n\t" + "vrgather.vi v3, v10, 3 \n\t" + + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v1 \n\t" + "vadd.vv v18, v18, v2 \n\t" + "vadd.vv v19, v19, v3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc*4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#else + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vsll.vi v8, v8, 3 \n\t" // asum * 8 + "vfcvt.f.x.v v9, v8 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vrgather.vi v10, v9, 0 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v16, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v17, v16, 4 \n\t" + "vnpack4.vv v12, v16, v17, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v16, v10, v10,0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v20, v16, v16,0 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vpack.vv v18, v20, v20, 0 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + "vfwmul.vv v18, v21, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#endif + } + } else { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2\n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + // load a sum + "lh s1, (%[A]) \n\t" + "lh s2, 2(%[A]) \n\t" + "lh s3, 4(%[A]) \n\t" + "lh s4, 6(%[A]) \n\t" + "addi %[A], %[A], 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v0, v10, s1 \n\t" + "vwmul.vx v2, v10, s2 \n\t" + "vwmul.vx v4, v10, s3 \n\t" + "vwmul.vx v6, v10, s4 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v2 \n\t" + "vadd.vv v18, v18, v4 \n\t" + "vadd.vv v19, v19, v6 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST]\n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3", "s1", "s2", "s3", "s4"); + } + } +} + +void gemm_kernel_i8i4_hp_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_SUBBLKS_PER_SUPERBLK = 8; + constexpr size_t K_SUBBLK_LEN = 32; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + GGML_ASSERT(count_m >= 4); + + // Contract: + // - computes a 4-row x 32-col tile per inner invocation + // - A is q8 HP packed in m4 layout, one logical K256 block at a time + // - B is q4 HP packed in N32 tiles, optionally with a separate zp area + // - tail-N is currently not handled here; the caller must provide full N32 tiles + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * K_SUBBLKS_PER_SUPERBLK + + (quant_b_zp ? NB_COLS * K_SUBBLKS_PER_SUPERBLK * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * 4; + const size_t a_subblk_stride = q8_hp_blk_size(K_SUBBLK_LEN, false, false) * 4; + + if (quant_b_zp != nullptr) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the with-zp path. + // + // A: M4 x K256 q8 HP block + // - split into 8 x K32 subblocks + // - each K32 subblock is 136B: + // 8B = 4 x fp16 row scales + // 128B = 4 x int8[32] row payloads + // - trailer after 8 subblocks is 72B: + // 4 rows x fp16[8] a_sum values, indexed as [row][ksi] + // 4 rows x fp16 scale_avg tail + // + // B: N32 x K256 q4 HP block with explicit zp area + // - each K32 subblock is 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload for 32 columns x 32 k-elements + // - zp is stored separately, not interleaved with the 576B payload block + // - one K256 superblock is laid out as: + // 8 x (scale + qs) blocks = 4608B + // 8 x zp[32] = 256B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // row1/row2/row3 are at +16/+32/+48 bytes + // - s5: current B (scale + qs) K32 subblock base + // - s6: current B zp[32] base for this ksi + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576, B_zp += 32 + // - per ki : skip the 72B A trailer and advance B to the next 4864B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + const _Float16 hp_scale_0125 = (_Float16) 0.125f; + + // VPR grouping used below: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : zp u8 / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "li t1, 4608 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + // load Bzp[32] for the current ksi from the dedicated zp area + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (s6) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + "vfwcvt.f.xu.v v10, v8 \n\t" // uint8 -> fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + "fmul.h ft1, ft1, %[HP0125] \n\t" + "fmul.h ft2, ft2, %[HP0125] \n\t" + "fmul.h ft3, ft3, %[HP0125] \n\t" + "fmul.h ft4, ft4, %[HP0125] \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> mf2 + "vfmul.vv v10, v10, v16 \n\t" // zp * ascale * bscale; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v10, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v10, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v10, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v10, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "addi s6, s6, 32 \n\t" // next zp[32] + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + "mv s5, s6 \n\t" // s6 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), + [HP1] "f"(hp_scale_1), [HP0125] "f"(hp_scale_0125) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s5", "s6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", + "memory"); + } + return; + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the no-zp path. + // + // A layout is identical to the with-zp branch. + // + // B: N32 x K256 q4 HP block without explicit zp storage + // - each K32 subblock is still 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload + // - zp is implicit and treated as a constant value 8 in the kernel + // - one K256 superblock therefore contains only: + // 8 x (scale + qs) blocks = 4608B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // - s5: current B (scale + qs) K32 subblock base + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576 + // - per ki : skip the 72B A trailer and advance B to the next 4608B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + + // VPR grouping used below matches the with-zp path: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : implicit zp lane / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v16, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v16, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v16, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v16, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" //N32in1register + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + // s5 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), [HP1] "f"(hp_scale_1) + : "t0", "t2", "t3", "t4", "t5", "t6", "s5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v10", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", "memory"); + } + return; + } +} + +void gemm_kernel_i8mxfp4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 1); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // MXFP4 no-zp: per column per k-block stride = scale_e8m0(1B) + qs(16B) + qh(4B) = 21B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // qh sign/high-bit mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // qs packed 4-bit magnitudes: n×k_blks×16 + n * k_blks * sizeof(uint8_t); // scale: n×k_blks×1 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (q8 block with per-block scale and stored sum field): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (u8×32) + // s5 = pB qh (sign/high-bit mask, 128B) + // s6 = pB qs (packed 4-bit magnitudes, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales u8 (loaded as bytes; later widened) + // v0 = qh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = qs unpack/pack temporaries (build signed vmadot lanes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32 \n\t" // s5 = pBh (pB + 32B scale) + "addi s6, %[pB], 32+128 \n\t" // s6 = pBs (pB + 32 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 672B = 32(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load qs (512B = 4 VRF) from s6, advance s6 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = qs N32K32 packed 4-bit magnitudes + "addi s6, s6, 128*4+128+32 \n\t" // s6 += 672 (512+128+32) + + // ---- load B scale (32B = 32×u8) from s4, advance s4 by 672 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_u8 × 32 + "addi s4, s4, 32+128*4+128 \n\t" // s4 += 672 (32+512+128) + + // ---- load qh (128B = 1 VRF) from s5, advance s5 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = qh N32K32 sign/high-bit packed + "addi s5, s5, 128+32+128*4 \n\t" // s5 += 672 (128+32+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + // ---- Decode packed MXFP4 payload into a vmadot-friendly signed-lane layout ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to 0 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- int8 dot products over the decoded MXFP4 lane groups ---- + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v2, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v2, 1 \n\t" + "vadd.vi v28, v2, -1 \n\t" + "vsll.vi v28, v28, 23 \n\t" + "vsll.vv v28, v30, v2, v0.t \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8mxfp4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 4); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + // MXFP4 block layout per K32/N32 tile: + // [scale_e8m0 x 32][qh sign/high-bit mask x 128B][qs packed 4-bit magnitudes x 512B] + // There is no explicit zp stream; qh is combined with qs to reconstruct signed MXFP4 values. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 32 \n\t" + "addi t2, %[B], 160 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "addi %[B], %[B], 672 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + + // Decode the packed MXFP4 payload once for the whole tile and expand it + // into a vmadot-friendly layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v26, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v26, 1 \n\t" + "vadd.vi v24, v26, -1 \n\t" + "vsll.vi v18, v24, 23 \n\t" + "vsll.vv v18, v30, v26, v0.t \n\t" + + // Row 0: dot(A0, decoded MXFP4 lane groups), accumulate in int32 and + // then apply A/B scaling. + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i5_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // ========================================================================= + // i8i5: 8-bit activation × 5-bit weight (4-bit low + 1-bit high mask) + // + // B layout per N32K32 k-block (no-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 191] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [192.. 703] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 704B per k-block stride + // + // B layout per N32K32 k-block (with-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 95 ] : zp_uint8 × 32 (32B) + // [96 .. 223] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [224.. 735] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 736B per k-block stride + // + // Bh format per N8K32 sub-block (32B): + // K rows × N cols × 1bit packed as bytes (8 cols per byte, K groups of 4B) + // Byte k gives 8 mask bits for columns N7..N0 at k-th K-element. + // + // Computation: + // B5bit_signed = (Bs | (Bh << 4)) - zp + // dot(A, B5) = dot(A, Bs_u4) + 16*dot(A, Bh_u1) - zp*asum + // No-zp: implicit zp = 16 (unsigned [0..31] centered at 16) + // With-zp: explicit zp from data + // + // ========================================================================= + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 no-zp: per column per k-block stride = fp16(2B) + i4(16B) + i1(4B) = 22B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // Bh i1 mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // Bs i4 data: n×k_blks×16 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32) + // s5 = pB Bh (i1 mask, 128B) + // s6 = pB Bs (i4 packed, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBh (pB + 64B scale) + "addi s6, %[pB], 32*2+128 \n\t" // s6 = pBs (pB + 64 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 704B = 64(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+64 \n\t" // s6 += 704 (512+128+64) + + // ---- load B scale (64B = 32×fp16) from s4, advance s4 by 704 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64+128*4+128 \n\t" // s4 += 704 (64+512+128) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+64+128*4 \n\t" // s5 += 704 (128+64+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + //// vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 with-zp: per column per k-block stride = fp16(2B)+zp(1B)+i4(16B)+i1(4B)=23B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // Bs i4: n×k_blks×16 + n * k_blks * (blk_len / 8) + // Bh i1: n×k_blks×4 + n * k_blks * sizeof(uint8_t) + // zp: n×k_blks×1 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32); 每个 k-block 先 +64 指向 zp,再 +672 到下一个 block + // s5 = pB Bh (i1 mask, 128B) (offset +96) + // s6 = pB Bs (i4 packed, 512B) (offset +224) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) / later reused to hold Bzp bytes + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBh (pB + 64B scale + 32B zp = pB+96) + "addi s6, %[pB], 32*3+128 \n\t" // s6 = pBs (pB + 96 + 128 = pB+224) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 736B = 64(scale) + 32(zp) + 128(Bh) + 512(Bs) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+96 \n\t" // s6 += 736 (512+128+96) + + // ---- load B scale (64B = 32×fp16) from s4; then s4 points to zp[32] ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64 \n\t" // s4 += 64 (now points to zp) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+96+128*4 \n\t" // s5 += 736 (128+96+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4+128 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8i5_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + + GGML_UNUSED(count_m); + GGML_UNUSED(blk_len); + + // This kernel computes a 4x32 output tile. For each K32 block we decode the + // packed Q5 weights once and reuse the decoded vectors across the 4 A rows. + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + if (quant_b_zp) { + // Q5_1 block layout per K32/N32 tile: + // [scale_fp16 x 32][zp_u8 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> zp stream, t2 -> qh bitmask stream, s5 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 96 \n\t" + "addi s5, %[B], 224 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t2) \n\t" + "vl4r.v v8, (s5) \n\t" + "addi %[B], %[B], 736 \n\t" + + // Decode Q5 payload once for the whole tile: + // 1) split `qs` low/high nibbles, + // 2) repack into bytes, + // 3) use the `qh` mask to inject bit4 (+16) where needed, + // 4) expand into the vmadot-friendly layout reused by all 4 rows. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v3, (t1) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * zp, then scale by A/B scales. + // The widen/mul correction sequence intentionally matches the proven m1 Q5_1 path. + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vwaddu.vx v28, v3, x0 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v12, v28, s1 \n\t" + "vwmul.vx v14, v28, s2 \n\t" + "vwmul.vx v20, v28, s3 \n\t" + "vwmul.vx v22, v28, s4 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v12 \n\t" + "vadd.vv v25, v25, v14 \n\t" + "vadd.vv v26, v26, v20 \n\t" + "vadd.vv v27, v27, v22 \n\t" + "vfcvt.f.x.v v12, v24 \n\t" + "vfcvt.f.x.v v14, v25 \n\t" + "vfcvt.f.x.v v20, v26 \n\t" + "vfcvt.f.x.v v22, v27 \n\t" + "vfmul.vv v12, v12, v18 \n\t" + "vfmul.vv v14, v14, v18 \n\t" + "vfmul.vv v20, v20, v18 \n\t" + "vfmul.vv v22, v22, v18 \n\t" + "vfmacc.vf v4, fa0, v12 \n\t" + "vfmacc.vf v5, fa1, v14 \n\t" + "vfmacc.vf v6, fa2, v20 \n\t" + "vfmacc.vf v7, fa3, v22 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { + // Q5_0 block layout per K32/N32 tile: + // [scale_fp16 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + // There is no explicit zp stream; the implicit midpoint correction is +16. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 192 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + "addi %[B], %[B], 704 \n\t" + + // Decode Q5 payload once for the whole tile and expand it into the vmadot layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * 16 (implicit Q5_0 midpoint correction). + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "slli s1, s1, 4 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "slli s2, s2, 4 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "slli s3, s3, 4 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "slli s4, s4, 4 \n\t" + "vadd.vx v24, v24, s1 \n\t" + "vadd.vx v25, v25, s2 \n\t" + "vadd.vx v26, v26, s3 \n\t" + "vadd.vx v27, v27, s4 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i8_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 2048bit + // 4VRF, N32K32, 8192bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... + + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*64 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4 \n\t" + "vl4r.v v8, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*8 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "addi s2, s2, 6+32 \n\t" // AScale + Asum(FP32+i16) + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v28, v8, v9, 1 \n\t" + "vupack.vv v30, v10, v11, 1 \n\t" + + "vslidedown.vi v4, v3, 4 \n\t" + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v24, i8 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadot v18, v3, v26, i8 \n\t" // M0 N8 - N15 + "vmadot v20, v3, v28, i8 \n\t" // M0 N16 - N23 + "vmadot v22, v3, v30, i8 \n\t" // M0 N24 - N31 + + "vmadot v16, v4, v25, i8 \n\t" + "vmadot v18, v4, v27, i8 \n\t" + "vmadot v20, v4, v29, i8 \n\t" + "vmadot v22, v4, v31, i8 \n\t" + + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i8_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = k_blks * sizeof(ggml_fp16_t) + k_blks * blk_len; + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16+8 \n\t" // Ascl+Asum; FP32*4+i16*4 + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + "vl4r.v v8, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v2, v0, v0, 1 \n\t" + + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v4, v8, v9, 1 \n\t" + "vupack.vv v6, v10, v11, 1 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot v16, v2, v24, i8 \n\t" + "vmadot v18, v2, v26, i8 \n\t" + "vmadot v20, v2, v4, i8 \n\t" + "vmadot v22, v2, v6, i8 \n\t" + "vmadot v16, v3, v25, i8 \n\t" + "vmadot v18, v3, v27, i8 \n\t" + "vmadot v20, v3, v5, i8 \n\t" + "vmadot v22, v3, v7, i8 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz %[BK], BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +} + +void moe_m2_gemm_kernel_i8i4_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, + ldc); +#else + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.x v10, t1 \n\t" + "vmv.v.x v11, t2 \n\t" + + "vpack.vv v18, v10, v11, 1 \n\t" + "vsll.vi v18, v18, 3 \n\t" // mul 8 + "vfcvt.f.x.v v18, v18 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vor.vv v19, v18, v18 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { +# if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +# else + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + // asum*zp + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v2, v10, t1 \n\t" + "vwmul.vx v4, v10, t2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "vfcvt.f.x.v v2, v2 \n\t" + "vfcvt.f.x.v v4, v4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwcvt.f.f.v v16, v20 \n\t" + + "vfwcvt.f.f.v v18, v14 \n\t" + + // +asum*zp + "vsetvli t0, x0, e32, m1 \n\t" + "vfadd.vv v16, v16, v2 \n\t" + "vfadd.vv v17, v17, v4 \n\t" + "vfmul.vv v16, v16, v18 \n\t" + "vfmul.vv v17, v17, v18 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +# endif + } +#endif +} + +void moe_m2_gemm_kernel_i8i5_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + GGML_UNUSED(blk_len); + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/Bh/Bs and advance to the next q5_0 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" // v0 = Bh N32K32 1-bit packed + "addi %[B], %[B], 128 \n\t" + "vl4r.v v8, (%[B]) \n\t" // v8..v11 = Bs N32K32 i4 + "addi %[B], %[B], 512 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + "slli t3, t3, 4 \n\t" + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + "vadd.vx v25, v25, t3 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v1 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + "addi t5, %[B], 64 \n\t" // t5 = zp (32B) + "addi t6, %[B], 96 \n\t" // t6 = qh (128B) + "addi s1, %[B], 224 \n\t" // s1 = qs (512B) + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/zp/Bh/Bs and advance to the next q5_1 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 736 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t6) \n\t" // v0 = Bh N32K32 1-bit packed + "addi t6, t6, 736 \n\t" + "vl4r.v v8, (s1) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s1, s1, 736 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // q5_1 uses explicit zp, so keep a_sum unshifted here. + // [4*32]*2 + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (t5) \n\t" // v4 = Bzp N32 uint8 + "addi t5, t5, 736 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v4, x0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + "vwmul.vx v31, v28, t3 \n\t" + + // b_scale fp16 -> fp32 + "vfwcvt.f.f.v v28, v1 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + "vadd.vv v25, v25, v31 \n\t" + + // a_scale * b_scale; + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } +} + +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i2k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i2k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i2k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, + ldc); +#else + gemm_kernel_i8i2k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i3k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i3k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + moe_m2_gemm_kernel_i8i4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i8_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i8_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + //moe_m2_gemm_kernel_i8mxfp4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i5_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i5_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i5_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + moe_m2_gemm_kernel_i8i5_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 2; +} + +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.cpp b/ggml/src/ggml-cpu/spacemit/ime_env.cpp new file mode 100644 index 00000000000..a13ba391da2 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.cpp @@ -0,0 +1,320 @@ +#include "ime_env.h" + +#include "ggml-impl.h" +#include "spine_mem_pool.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +bool spine_core_info::get_spine_core_info(std::vector & result) { + static std::unordered_map spine_march_mapping_ = { + {0x8000000058000001, spine_core_arch_id::core_arch_x60 }, + { 0x8000000041000001, spine_core_arch_id::core_arch_a60 }, + { 0x8000000058000002, spine_core_arch_id::core_arch_x100}, + { 0x8000000041000002, spine_core_arch_id::core_arch_a100}, + }; + + result.clear(); + std::ifstream file("/proc/cpuinfo"); + std::string line; + + std::vector> cpu_info_list; + + uint64_t current_processor = spine_invalid_core_id; + uint64_t current_marchid = 0; + bool has_processor = false; + bool has_marchid = false; + + if (!file.is_open()) { + return false; + } + + while (std::getline(file, line)) { + if (line.substr(0, 9) == "processor") { + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + current_processor = std::stoi(line.substr(colon_pos + 1)); + has_processor = true; + } + + has_marchid = false; + } else if (line.substr(0, 7) == "marchid") { + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + std::string marchid_str = line.substr(colon_pos + 1); + marchid_str.erase(std::remove_if(marchid_str.begin(), marchid_str.end(), isspace), marchid_str.end()); + current_marchid = std::stoull(marchid_str, nullptr, 16); + has_marchid = true; + } + } + } + + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + if (has_processor && has_marchid) { + for (auto & cpu_info : cpu_info_list) { + if (cpu_info[0] != spine_invalid_core_id && + spine_march_mapping_.find(cpu_info[1]) != spine_march_mapping_.end()) { + auto core_info = spine_core_info(); + core_info.core_id = cpu_info[0]; + core_info.arch_id = spine_core_arch_id(spine_march_mapping_[cpu_info[1]]); + + result.push_back(core_info); + } + } + } + + return has_processor && has_marchid; +} + +namespace { +uint16_t hex_string_to_u16(const std::string & hex_str) { + try { + size_t pos = 0; + if (hex_str.substr(0, 2) == "0x" || hex_str.substr(0, 2) == "0X") { + pos = 2; + } + unsigned long result = std::stoul(hex_str.substr(pos), nullptr, 16); + if (result > std::numeric_limits::max()) { + throw std::out_of_range("Converted value is out of range for uint16_t"); + } + return static_cast(result); + } catch (const std::invalid_argument & e) { + throw std::invalid_argument("Invalid hexadecimal string"); + } catch (const std::out_of_range & e) { + throw; + } +} + +const char * spine_mem_pool_backend_to_string(spine_mem_pool_backend backend) { + switch (backend) { + case spine_mem_pool_backend::none: + return "NONE"; + case spine_mem_pool_backend::posix_memalign: + return "POSIX"; + case spine_mem_pool_backend::transparent_hugepage: + return "HPAGE"; + case spine_mem_pool_backend::hugetlb_1g: + return "HPAGE1GB"; + } + + return "unknown"; +} + +spine_mem_pool_backend parse_mem_backend(const char * mem_backend_str) { + if (mem_backend_str == nullptr || mem_backend_str[0] == '\0') { + return spine_mem_pool_backend::transparent_hugepage; + } + + std::string value(mem_backend_str); + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast(std::tolower(ch)); }); + + if (value == "none") { + return spine_mem_pool_backend::none; + } + + if (value == "posix") { + return spine_mem_pool_backend::posix_memalign; + } + + if (value == "hpage") { + return spine_mem_pool_backend::transparent_hugepage; + } + + if (value == "hpage1gb") { + return spine_mem_pool_backend::hugetlb_1g; + } + + throw std::runtime_error("invalid SPACEMIT_MEM_BACKEND: " + value + ", expected NONE, POSIX, HPAGE or HPAGE1GB"); +} +} // namespace + +spine_env_info::spine_env_info() { + num_cores = static_cast(std::thread::hardware_concurrency()); + spine_core_info::get_spine_core_info(core_info_list); + + // special for x60 K1 + if (core_info_list.size() == 8 && core_info_list[0].arch_id == spine_core_arch_id::core_arch_x60) { + for (int i = 0; i < 4; i++) { + core_info_list[i].arch_id = spine_core_arch_id::core_arch_a60; + } + } + + // special for qemu + if (core_info_list.size() == 0) { + char * spine_core_arch_str = getenv("SPACEMIT_CORE_ARCH"); + if (spine_core_arch_str != nullptr) { + auto arch_id = hex_string_to_u16(spine_core_arch_str); + for (int i = 0; i < num_cores; i++) { + auto core_info = spine_core_info(); + core_info.core_id = i; + core_info.arch_id = spine_core_arch_id{ arch_id }; + core_info_list.push_back(core_info); + } + } + } + + if (core_info_list.size() == 0) { + throw std::runtime_error( + "Failed to get SPACEMIT_CORE_ARCH from environment or failed to parse it from /proc/cpuinfo"); + } + + char * spine_perfer_core_arch_str = getenv("SPACEMIT_PERFER_CORE_ARCH"); + if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str != "") { + perfer_core_arch_id = spine_core_arch_id{ hex_string_to_u16(spine_perfer_core_arch_str) }; + } + + char * spine_perfer_core_id_str = getenv("SPACEMIT_PERFER_CORE_ID"); + std::vector perfer_core_id_vec; + if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str != "") { + std::string perfer_core_id_str(spine_perfer_core_id_str); + size_t start = 0; + size_t end = 0; + while ((end = perfer_core_id_str.find(',', start)) != std::string::npos) { + std::string core_id_substr = perfer_core_id_str.substr(start, end - start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + start = end + 1; + } + std::string core_id_substr = perfer_core_id_str.substr(start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + } + + perfer_core_ids.reserve(num_cores); + if (perfer_core_arch_id == spine_core_arch_id::core_arch_none) { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + num_perfer_cores++; + perfer_core_arch_id = core_arch_id; + cpu_mask |= (1ULL << core_info.core_id); + perfer_core_ids.push_back(core_info.core_id); + } + } + } else { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + num_perfer_cores++; + cpu_mask |= (1ULL << core_info.core_id); + + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + perfer_core_ids.push_back(core_info.core_id); + } + } + } + if (num_perfer_cores == 0) { + GGML_ABORT("can not find core with arch id %x for SPACEMIT_PERFER_CORE_ARCH in core info list\n", + (uint16_t) perfer_core_arch_id); + } + } + + if (perfer_core_id_vec.size() > 0) { + perfer_core_ids.clear(); + cpu_mask = 0; + num_perfer_cores = 0; + for (int core_id : perfer_core_id_vec) { + if (core_id < 0 || core_id >= num_cores) { + GGML_ABORT("invalid core id in SPACEMIT_PERFER_CORE_ID: %d, should be between 0 and %d\n", core_id, + num_cores - 1); + } + auto core_info = core_info_list[core_id]; + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + cpu_mask |= (1ULL << core_id); + perfer_core_ids.push_back(core_id); + } else { + GGML_ABORT( + "core id %d in SPACEMIT_PERFER_CORE_ID has arch id %x which does not match " + "SPACEMIT_PERFER_CORE_ARCH %x\n", + core_id, (uint16_t) core_arch_id, (uint16_t) perfer_core_arch_id); + } + } + std::string perfer_core_id_vec_str; + for (int core_id : perfer_core_id_vec) { + perfer_core_id_vec_str += std::to_string(core_id) + ","; + } + perfer_core_id_vec_str.pop_back(); + GGML_LOG_DEBUG("SPACEMIT_PERFER_CORE_ID is set, perferred core ids: %s\n", perfer_core_id_vec_str.c_str()); + num_perfer_cores = static_cast(perfer_core_id_vec.size()); + } + + use_ime1 = perfer_core_arch_id == spine_core_arch_id::core_arch_a60 || + perfer_core_arch_id == spine_core_arch_id::core_arch_x100; + + use_ime2 = perfer_core_arch_id == spine_core_arch_id::core_arch_a100; + + mem_backend = parse_mem_backend(getenv("SPACEMIT_MEM_BACKEND")); + char * spine_disable_tcm_str = getenv("SPACEMIT_DISABLE_TCM"); + auto user_disable_tcm = spine_disable_tcm_str != nullptr && strcmp(spine_disable_tcm_str, "0") != 0; + + if (!user_disable_tcm) { + spine_mem_pool_tcm_info tcm_info; + if (spine_mem_pool_tcm_init(&tcm_info)) { + use_tcm = tcm_info.available; + tcm_blk_size = tcm_info.blk_size; + GGML_LOG_DEBUG("CPU_RISCV64_SPACEMIT: tcm is available, blk_size: %zu, blk_num: %zu, is_fake_tcm: %d\n", + tcm_info.blk_size, tcm_info.blk_num, tcm_info.is_fake_tcm); + + for (auto & core_info : core_info_list) { + auto core_arch_head = (uint16_t) (core_info.arch_id) >> 12; + if (core_arch_head != 0xA) { + aicpu_id_offset++; + } else { + break; + } + } + } + } + + GGML_LOG_DEBUG( + "CPU_RISCV64_SPACEMIT: num_cores: %d, num_perfer_cores: %d, perfer_core_arch_id: %x, exclude_main_thread: %d, " + "use_ime1: %d, use_ime2: %d, mem_backend: %s, cpu_mask: %lx, aicpu_id_offset: %d\n", + num_cores, num_perfer_cores, (uint16_t) perfer_core_arch_id, exclude_main_thread, use_ime1, use_ime2, + spine_mem_pool_backend_to_string(mem_backend), cpu_mask, aicpu_id_offset); + + const size_t init_barrier_size = sizeof(spine_barrier_t) * spine_init_barrier_count; + init_barrier = + static_cast(spine_mem_pool_shared_mem_alloc(init_barrier_size, alignof(spine_barrier_t))); + if (init_barrier != nullptr) { + init_barrier_is_shared_mem = true; + } else { + GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n", + __func__); + init_barrier = new spine_barrier_t[spine_init_barrier_count]; + } + + spine_barrier_init(init_barrier, spine_init_barrier_count, 2); +} + +spine_env_info::~spine_env_info() { + if (init_barrier_is_shared_mem) { + spine_mem_pool_shared_mem_free(init_barrier); + } else { + delete[] init_barrier; + } + + init_barrier = nullptr; + init_barrier_is_shared_mem = false; +} + +spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.h b/ggml/src/ggml-cpu/spacemit/ime_env.h new file mode 100644 index 00000000000..a6ca06d26a4 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.h @@ -0,0 +1,55 @@ +#pragma once + +#include "spine_barrier.h" +#include "spine_mem_pool.h" + +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +constexpr uint64_t spine_invalid_core_id = 0xFFFFFFFF; +constexpr size_t spine_init_barrier_count = 16; + +enum class spine_core_arch_id : uint16_t { + core_arch_none = 0, + core_arch_x60 = 0x503C, + core_arch_x100 = 0x5064, + core_arch_x200 = 0x50C8, + core_arch_a60 = 0xA03C, + core_arch_a100 = 0xA064, + core_arch_a200 = 0xA0C8, +}; + +struct spine_core_info { + uint64_t core_id{ spine_invalid_core_id }; + spine_core_arch_id arch_id{ spine_core_arch_id::core_arch_none }; + + static bool get_spine_core_info(std::vector & result); +}; + +struct spine_env_info { + std::vector core_info_list; + std::vector perfer_core_ids; + int aicpu_id_offset{ 0 }; + int num_cores{ 0 }; + int num_perfer_cores{ 0 }; + spine_core_arch_id perfer_core_arch_id{ spine_core_arch_id::core_arch_none }; + bool exclude_main_thread{ false }; + bool use_ime2{ false }; + bool use_ime1{ false }; + bool use_tcm{ false }; + spine_mem_pool_backend mem_backend{ spine_mem_pool_backend::transparent_hugepage }; + uint64_t tcm_blk_size{ 0 }; + uint64_t cpu_mask{ 0 }; + spine_barrier_t * init_barrier{ nullptr }; + bool init_barrier_is_shared_mem{ false }; + + spine_env_info(); + ~spine_env_info(); +}; + +extern spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 75706341505..0a1fafffb25 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,26 +1,189 @@ #pragma once +#include #include +#include + +namespace spacemit_kernels { + +#define BLOCK_QNK_LEN 256 + +template struct nrow_block_q2_k { + // [4bit scale + 4bit zp] * N * 16 + uint8_t scales[N * BLOCK_QNK_LEN / 16]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; + uint16_t zeros16[N]; +}; + +template struct nrow_block_q3_k { + // [8bit scale] * N * 16 + int8_t scales[N * 16]; + // [b0, b1, b2, b3, b4, b5, b6, b7] ... [b248, b249, b250, b251, b252, b253, b254, b255] + uint8_t hmask[N * BLOCK_QNK_LEN / 8]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; +}; + +template struct nrow_block_mxfp4 { + uint8_t e[N]; + uint8_t qh[4 * N]; + uint8_t qs[16 * N]; +}; + +template struct __attribute__((packed)) nrow_block_q5_1 { + uint16_t scales16[N]; + uint8_t zp[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_1<1>) == sizeof(uint8_t) + 22, "wrong nrow_block_q5_1 block size/padding"); + +template struct __attribute__((packed)) nrow_block_q5_0 { + uint16_t scales16[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_0<1>) == 22, "wrong nrow_block_q5_0 block size/padding"); + +using gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t *, const uint8_t *, const uint8_t *, float *, size_t, size_t, size_t, size_t)>; + +using moe_gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t **, const uint8_t *, const uint8_t *, float **, size_t, size_t, size_t, size_t)>; -namespace sqnbitgemm_spacemit_ime { namespace ime1 { -size_t gemm_kernel_i8i4(size_t blk_len, - const std::byte * quant_a_ptr, - const std::byte * quant_b_data, - const float * quant_b_scale, - const std::byte * quant_b_zp, - float * c_ptr, - size_t count_m, - size_t count_n, - size_t count_k, - size_t block_count_k, - size_t ldc, - const float * bias, - const size_t scale_stride); - -void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); - -void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime + +namespace ime2 { +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/repack.cpp b/ggml/src/ggml-cpu/spacemit/repack.cpp new file mode 100644 index 00000000000..3c879c4b7a0 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.cpp @@ -0,0 +1,1795 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "repack.h" + +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ime_kernels.h" + +#include +#include +#include +#include + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#else +#error "riscv not enabled in this build" +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// clang-format on + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); + +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block_with_zp<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16 + 32 * sizeof(uint8_t), + "wrong block_with_zp<4,32> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +using block_q4_0x32 = block<4, 32>; +using block_q4_1x32 = block_with_zp<4, 32>; +using block_q8_0x32 = block<8, 32>; + +struct block_q4_0x32x256 { + block_q4_0x32 blocks[8]; // [f16 * 32 | i4 * 32 * 32] * 8 +}; + +struct block_q4_1x32x256 { + block_q4_0x32 blocks[8]; + uint8_t zps[32 * 8]; +}; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_q4_0x32 make_block_q4_0x32(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x32 out; + assert(QK4_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b17] ......... [b30 b31] + out.qs[i * QK4_0 / 2 + QK4_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q4_1x32 make_block_q4_1x32(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x32 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[i * QK4_1 / 2 + QK4_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q8_0x32 make_block_q8_0x32(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x32 out; + GGML_ASSERT(QK8_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * QK8_0, in[i].qs, QK8_0); + } + + return out; +} + +static int repack_q2_k_to_q2_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const block_q2_K * src = (const block_q2_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q2_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint8_t qs_aux[256] = { 0 }; + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q2_K * src_block = &src[(b + i) * nblocks + x]; + + // scale for [16, N] + for (int j = 0; j < 16; j++) { + auto zp_aux = (dst->scales[j * nrows_interleaved + i]) & 0xF0; + + dst->scales[j * nrows_interleaved + i] = (src_block->scales[j] & 0x0F) | zp_aux; + } + + // zp for [N, 16] + for (int j = 0; j < 16; j++) { + auto scale_aux = (dst->scales[16 * i + j]) & 0x0F; + + dst->scales[16 * i + j] = (src_block->scales[j] & 0xF0) | scale_aux; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + dst->scales16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + dst->zeros16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + dst++; + } + } + + return 0; +} + +static int repack_q3_k_to_q3_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * src = (const block_q3_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q3_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint32_t b_scale_aux[4] = { 0 }; + uint8_t qs_aux[256] = { 0 }; + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q3_K * src_block = &src[(b + i) * nblocks + x]; + + uint32_t * auxs = b_scale_aux; + int8_t * scale = (int8_t *) auxs; + memcpy(auxs, src_block->scales, 12); + + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + for (int j = 0; j < 16; j++) { + dst->scales[j * nrows_interleaved + i] = scale[j] - 32; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + //memcpy(dst->hmask + i * 32, src_block->hmask, 32); + + // from nrows_interleaved * [32byte] + // to 16 * [nrows_interleaved * uint16_t] + uint16_t * dst_mask = ((uint16_t *) dst->hmask) + i; + for (int j = 0; j < 16; j++, dst_mask += nrows_interleaved) { + uint8_t b_shift = j / 2; + uint8_t * b_mask_col = (uint8_t *) (src_block->hmask + (j % 2) * 16); + // b0 - b15 + uint16_t msk_out_0 = 0; + + for (int k = 0; k < 8; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + for (int k = 8; k < 16; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + + dst_mask[0] = msk_out_0; + } + + dst->scales16[i] = src_block->d; + } + + dst++; + } + } + + return 0; +} + +static int repack_q4_0_to_q4_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32x256 * dst = (block_q4_0x32x256 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + dst->blocks[j] = make_block_q4_0x32(dst_tmp, interleave_block); + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_1_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32x256 * dst = (block_q4_1x32x256 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + + block_q4_0x32 * dst_block = &dst->blocks[j]; + uint8_t * dst_zp = dst->zps + j * nrows_interleaved; + + for (int i = 0; i < nrows_interleaved; i++) { + float d = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + + dst_block->d[i] = GGML_FP32_TO_FP16(d); + dst_zp[i] = static_cast(mid); + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + k] = + (dst_tmp[i].qs[k * 2] & 0x0F) | ((dst_tmp[i].qs[k * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + QK4_1 / 4 + k] = + ((dst_tmp[i].qs[k * 2] & 0xF0) >> 4) | (dst_tmp[i].qs[k * 2 + 1] & 0xF0); + } + } + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_0_to_q4_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack. +static int repack_q4_0_to_q4_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_0 / 2; // 16 + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + // d is at offset 0 of each block_q4_0, stride between rows = row_stride + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + // For each row i: + // src qs[16]: [b0|b16] [b1|b17] ... [b15|b31] (lo nibble = b_j, hi nibble = b_{j+16}) + // dst qs low 8B: (qs[2j] & 0x0F) | ((qs[2j+1] & 0x0F) << 4) for j=0..7 + // dst qs high 8B: ((qs[2j] >> 4)) | (qs[2j+1] & 0xF0) for j=0..7 + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); // qs[0], qs[2], ..., qs[14] + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); // qs[1], qs[3], ..., qs[15] + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_1_to_q4_1_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack + zp computation. +static int repack_q4_1_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_1 / 2; // 16 + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_1); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_1 * col_src = src + x; + + // --- 1) Gather d and m, compute zp = clamp(nearbyint(-m/d), 0, 15) --- + // block_q4_1 layout: [d(f16), m(f16), qs[16]] + // d is at byte offset 0, m is at byte offset 2 from each block start + { + const uint8_t * dm_base = (const uint8_t *) &col_src->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + ggml_half * d_dst = dst->d; + uint8_t * zp_dst = dst->zp; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + + // stride load d (f16) from each row + vuint16m1_t vd_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd_raw, vl); + + // stride load m (f16) from each row (offset +2 bytes from d) + vuint16m1_t vm_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + 2 + offset * row_stride), row_stride, vl); + + // convert to f32 for zp computation: zp = nearbyint(-m / d) + vfloat16m1_t vd_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vd_raw); + vfloat16m1_t vm_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vm_raw); + + // -m / d in f16 directly (SpaceMIT X60 supports f16 arithmetic) + vfloat16m1_t v_neg_m = __riscv_vfneg_v_f16m1(vm_f16, vl); + vfloat16m1_t v_ratio = __riscv_vfdiv_vv_f16m1(v_neg_m, vd_f16, vl); + + // Convert to f32 for nearbyint, then clamp + vfloat32m2_t v_ratio_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_ratio, vl); + + // Use integer rounding: convert f32 -> int (rounds to nearest) + vint32m2_t v_zp_i32 = __riscv_vfcvt_x_f_v_i32m2(v_ratio_f32, vl); + + // clamp to [0, 15] + v_zp_i32 = __riscv_vmax_vx_i32m2(v_zp_i32, 0, vl); + v_zp_i32 = __riscv_vmin_vx_i32m2(v_zp_i32, 15, vl); + + // narrow i32 -> u8 + vint16m1_t v_zp_i16 = __riscv_vncvt_x_x_w_i16m1(v_zp_i32, vl); + vint8mf2_t v_zp_i8 = __riscv_vncvt_x_x_w_i8mf2(v_zp_i16, vl); + vuint8mf2_t v_zp_u8 = __riscv_vreinterpret_v_i8mf2_u8mf2(v_zp_i8); + __riscv_vse8_v_u8mf2(zp_dst + offset, v_zp_u8, vl); + + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_k_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q6_k_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q8_0 dst_tmp[32]; + int8_t aux8[QK4_1]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrow_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + int i = 0; + for (; i < nrow_real; i++) { + const uint8_t * q4 = src[x + i * nblocks].ql; + const uint8_t * qh = src[x + i * nblocks].qh; + const int8_t * scales = src[x + i * nblocks].scales; + float d = GGML_FP16_TO_FP32(src[x + i * nblocks].d); + + q4 += 64 * (bi / 4); + qh += 32 * (bi / 4); + int8_t * GGML_RESTRICT a = aux8; + + int8_t bi_idx = bi % 4; + + if (bi_idx == 0) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + } + } else if (bi_idx == 1) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + } + } else if (bi_idx == 2) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + } + } else if (bi_idx == 3) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + } + a = aux8; + + float a_max_abs = 0.0f; + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + for (int l = 0; l < 16; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_0)); + } + + for (int l = 16; l < 32; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_1)); + } + + float reflect_scale = a_max_abs / ((1 << 7) - 1); + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + for (int l = 0; l < 16; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_0), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + for (int l = 16; l < 32; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_1), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + dst_tmp[i].d = GGML_FP32_TO_FP16(reflect_scale); + + memcpy(dst_tmp[i].qs, a, 32 * sizeof(int8_t)); + } + + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q6_k_to_q8_0_32_bl +// Vectorizes the Q6_K dequant -> requant pipeline using RVV intrinsics. +// For each sub-block (bi), dequant 32 Q6_K values to int6 -> apply two sub-block scales -> +// find max abs -> compute reflect_scale -> requant to int8 -> gather d with stride load. +static int repack_q6_k_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q6_K); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + // --- 1) Gather 32 d values with stride load --- + // We need to compute reflect_scale per row first, so gather d later. + // Process each row: dequant Q6_K sub-block -> requant to Q8_0 + for (int i = 0; i < nrows_interleaved; i++) { + const block_q6_K * src_blk = &src[x + i * nblocks]; + const uint8_t * q4 = src_blk->ql + 64 * (bi / 4); + const uint8_t * qh = src_blk->qh + 32 * (bi / 4); + const int8_t * scales = src_blk->scales; + float d = GGML_FP16_TO_FP32(src_blk->d); + + int8_t bi_idx = bi % 4; + + // --- Dequant 32 Q6_K values to int6 (range [-32, 31]) using RVV --- + // vl = 32 for e8m2 (VLEN=256) or loop for smaller VLEN + const size_t vl16 = __riscv_vsetvl_e8m1(16); + + vint8m1_t va_lo, va_hi; // 16 elements each + + if (bi_idx == 0) { + // a[l] = (q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_lo, 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_hi, 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 1) { + // a[l] = (q4[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 2, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 2, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 2) { + // a[l] = (q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 4, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 4, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else { // bi_idx == 3 + // a[l] = (q4[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 6, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 6, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } + + // --- Widen to i16 for scaled abs computation --- + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + + // Widen i8 -> i16 -> f32 for abs*scale computation + vint16m2_t va_lo_w = __riscv_vsext_vf2_i16m2(va_lo, vl16); + vint16m2_t va_hi_w = __riscv_vsext_vf2_i16m2(va_hi, vl16); + + // Compute |a[l] * scale_0| for lo half, |a[l] * scale_1| for hi half + vfloat32m4_t vf_lo = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_lo_w, vl16), vl16); + vfloat32m4_t vf_hi = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_hi_w, vl16), vl16); + + vfloat32m4_t vabs_lo = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_lo, scale_0, vl16), vl16); + vfloat32m4_t vabs_hi = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_hi, scale_1, vl16), vl16); + + // Find max abs across both halves + vfloat32m4_t vabs_max = __riscv_vfmax_vv_f32m4(vabs_lo, vabs_hi, vl16); + + // Reduce to scalar max + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m1_t vmax_red = __riscv_vfredmax_vs_f32m4_f32m1(vabs_max, vzero, vl16); + float a_max_abs = __riscv_vfmv_f_s_f32m1_f32(vmax_red); + + float reflect_scale = a_max_abs / 127.0f; + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + // --- Requant: a[l] = clamp(nearbyint(a[l] * reflect_scale_x), -128, 127) --- + vfloat32m4_t vscaled_lo = __riscv_vfmul_vf_f32m4(vf_lo, reflect_scale_0, vl16); + vfloat32m4_t vscaled_hi = __riscv_vfmul_vf_f32m4(vf_hi, reflect_scale_1, vl16); + + // fcvt.x rounds to nearest (using current rounding mode) + vint32m4_t vi_lo = __riscv_vfcvt_x_f_v_i32m4(vscaled_lo, vl16); + vint32m4_t vi_hi = __riscv_vfcvt_x_f_v_i32m4(vscaled_hi, vl16); + + // Clamp to [-128, 127] + vi_lo = __riscv_vmax_vx_i32m4(vi_lo, -128, vl16); + vi_lo = __riscv_vmin_vx_i32m4(vi_lo, 127, vl16); + vi_hi = __riscv_vmax_vx_i32m4(vi_hi, -128, vl16); + vi_hi = __riscv_vmin_vx_i32m4(vi_hi, 127, vl16); + + // Narrow i32 -> i16 -> i8 + vint16m2_t vi16_lo = __riscv_vncvt_x_x_w_i16m2(vi_lo, vl16); + vint16m2_t vi16_hi = __riscv_vncvt_x_x_w_i16m2(vi_hi, vl16); + vint8m1_t vi8_lo = __riscv_vncvt_x_x_w_i8m1(vi16_lo, vl16); + vint8m1_t vi8_hi = __riscv_vncvt_x_x_w_i8m1(vi16_hi, vl16); + + // Store d and qs directly into dst block + dst->d[i] = GGML_FP32_TO_FP16(reflect_scale); + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + __riscv_vse8_v_i8m1(dq, vi8_lo, vl16); + __riscv_vse8_v_i8m1(dq + 16, vi8_hi, vl16); + } + dst++; + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q8_0_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[0] % QK8_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrows_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + int i = 0; + for (; i < nrows_real; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q8_0_to_q8_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes scale gather + qs copy. +static int repack_q8_0_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK8_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q8_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q8_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Copy qs for each of the 32 rows (32 bytes per row) --- + { + for (int i = 0; i < 32; i++) { + const int8_t * sq = col_src[i * nblocks].qs; + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + + size_t len = QK8_0; + size_t idx = 0; + while (len > 0) { + size_t vl = __riscv_vsetvl_e8m2(len); + vint8m2_t vs = __riscv_vle8_v_i8m2(sq + idx, vl); + __riscv_vse8_v_i8m2(dq + idx, vs, vl); + idx += vl; + len -= vl; + } + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static void convert_mxfp4_to_5bit(const block_mxfp4 & src, spacemit_kernels::nrow_block_mxfp4<1> & dst) { + dst.e[0] = src.e; + + // Decode all 32 mxfp4 values to signed integers via kvalues_mxfp4 + int8_t vals[32]; + for (int j = 0; j < QK_MXFP4 / 2; j++) { + vals[j] = kvalues_mxfp4[src.qs[j] & 0xF]; + vals[j + QK_MXFP4 / 2] = kvalues_mxfp4[src.qs[j] >> 4]; + } + + // vals [b0, b1, b2, b3, ..., b30, b31] + // Pack abs into qs with reorder: [b0,b1]..[b14,b15]..[b30,b31] + for (int j = 0; j < QK_MXFP4 / 2; j++) { + uint8_t lo0 = static_cast(std::abs(vals[j * 2])); + uint8_t lo1 = static_cast(std::abs(vals[j * 2 + 1])); + dst.qs[j] = (lo0 & 0x0F) | ((lo1 & 0x0F) << 4); + } + + // Pack sign bits into qh[4] (32 bits total, 1 bit per weight) + // reorder: [0,1,2,...,15,16,17,...,31] after the qs reorder above + uint32_t sign_bits = 0; + for (int j = 0; j < 32; j++) { + if (vals[j] < 0) { + sign_bits |= (1u << j); + } + } + memcpy(dst.qh, &sign_bits, 4); +} + +static spacemit_kernels::nrow_block_mxfp4<32> make_block_mxfp4x32(spacemit_kernels::nrow_block_mxfp4<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_mxfp4<32> out; + GGML_ASSERT(QK_MXFP4 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.e[i] = in[i].e[0]; + } + + // qs: copy per-row 16 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * 16, in[i].qs, 16); + } + + // qh: copy per-row 4 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qh + i * 4, in[i].qh, 4); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_mxfp4<32> * dst = (spacemit_kernels::nrow_block_mxfp4<32> *) t->data; + const block_mxfp4 * src = (const block_mxfp4 *) data; + spacemit_kernels::nrow_block_mxfp4<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_MXFP4 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + convert_mxfp4_to_5bit(src[x + i * nblocks], dst_tmp[i]); + } + *dst++ = make_block_mxfp4x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static spacemit_kernels::nrow_block_q5_1<32> make_block_q5_1x32(spacemit_kernels::nrow_block_q5_1<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_1<32> out; + GGML_ASSERT(QK5_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + out.zp[i] = in[i].zp[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + QK5_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static spacemit_kernels::nrow_block_q5_0<32> make_block_q5_0x32(spacemit_kernels::nrow_block_q5_0<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_0<32> out; + GGML_ASSERT(QK5_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + QK5_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static int repack_q5_0_to_q5_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_0<32> * dst = (spacemit_kernels::nrow_block_q5_0<32> *) t->data; + const block_q5_0 * src = (const block_q5_0 *) data; + spacemit_kernels::nrow_block_q5_0<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_0 & s = src[x + i * nblocks]; + + dst_tmp[i].scales16[0] = s.d; + memcpy(dst_tmp[i].qs, s.qs, sizeof(dst_tmp[i].qs)); + memcpy(dst_tmp[i].qh, s.qh, sizeof(dst_tmp[i].qh)); + } + *dst++ = make_block_q5_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_1_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_1 * src = (const block_q5_1 *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_1 & s = src[x + i * nblocks]; + + float d = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + + if (d == 0.0f) { + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(std::fabs(m)); + dst_tmp[i].zp[0] = m < 0.0f ? 1 : 0; + memset(dst_tmp[i].qh, 0, sizeof(dst_tmp[i].qh)); + memset(dst_tmp[i].qs, m > 0.0f ? 0x11 : 0x00, sizeof(dst_tmp[i].qs)); + continue; + } + + float mid = std::nearbyintf(-m / d); + mid = std::min(31.0f, std::max(0.0f, mid)); + + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d); + dst_tmp[i].zp[0] = static_cast(mid); + + // qs: copy low 4 bits directly (same nibble packing) + memcpy(dst_tmp[i].qs, s.qs, QK5_1 / 2); + + // qh: copy 5th bit directly + memcpy(dst_tmp[i].qh, s.qh, 4); + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_k_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK5_1 == 8); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + + float d1 = d * sc; + float m1 = min * m; + + float mid = std::nearbyintf(m1 / d1); + mid = std::min(31.0f, std::max(0.0f, mid)); + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d1); + dst_tmp[i].zp[0] = static_cast(mid); + + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK5_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + + // Extract the 5th bit (qh) for this sub-block + // block_q5_K.qh[32]: for sub-block j, the 5th bit is at bit position j in qh[l] + // qs was reordered: dst_qs maps to src weights [0,16,1,17,...,15,31] + // So qh must follow the same reorder to stay aligned with qs + // dst qh[4] = 32 bits for 32 weights in the reordered layout: + // byte 0: weights 0..7 (from src_qh[0..7]) + // byte 1: weights 8..15 (from src_qh[8..15]) + // byte 2: weights 16..23 (from src_qh[16..23]) + // byte 3: weights 24..31 (from src_qh[24..31]) + const uint8_t * src_qh = src[x + i * nblocks].qh; + for (int bi = 0; bi < 4; bi++) { + uint8_t qh_byte = 0; + for (int k = 0; k < 8; k++) { + int src_idx = bi * 8 + k; + qh_byte |= ((src_qh[src_idx] >> j) & 1) << k; + } + dst_tmp[i].qh[bi] = qh_byte; + } + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +namespace ggml::cpu::riscv64_spacemit { + +template int repack(ggml_tensor *, const void *, size_t); + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_k_to_q2_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_k_to_q3_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_0_to_q4_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_0_to_q4_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_0_256_32_bl_ref(t, 32, data, data_size); +#else + //return repack_q4_0_to_q4_0_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_1_to_q4_1_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_1_256_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q6_k_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q6_k_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q8_0_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q8_0_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_0_to_q5_0_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_1_to_q5_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_k_to_q5_1_32_bl(t, 32, data, data_size); +} + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/repack.h b/ggml/src/ggml-cpu/spacemit/repack.h new file mode 100644 index 00000000000..950cbde7593 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml-common.h" +#include "ggml.h" + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(ggml_tensor * t, const void * data, size_t data_size); + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp new file mode 100644 index 00000000000..d2f89743622 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp @@ -0,0 +1,3178 @@ +#include "rvv_kernels.h" + +#include "common.h" +#include "ggml.h" +#include "ops.h" +#include "string.h" + +#include +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels::rvv { + +namespace { + +auto align_up(size_t value, size_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +static inline bool flash_attn_ext_supported_d_vlen1024_vf16(int64_t d) { + return d > 0 && d <= 128; +} + +static inline bool flash_attn_ext_supported_shape_vlen1024_vf16(int64_t DK, int64_t DV) { + return flash_attn_ext_supported_d_vlen1024_vf16(DK) && flash_attn_ext_supported_d_vlen1024_vf16(DV); +} + +static inline float reduce_sum_f32m4_vlen1024(vfloat32m4_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m4_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +static inline float reduce_sum_f32m2_vlen1024(vfloat32m2_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m2_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +// Adapted from ggml_v_expf_m2 in vec.h. This is accurate enough for softmax. +static inline vfloat32m2_t rvv_expf_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = + __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), u, vl), + u, vl); + + if (!__riscv_vcpop_m_b16(c, vl)) { + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + } + + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = + __riscv_vmerge_vvm_f32m2(__riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), c, vl); + return __riscv_vmerge_vvm_f32m2(r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), vl); +} + +static inline vfloat32m2_t rvv_tanh_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(x, vl); + const vfloat32m2_t neg_2_abs = __riscv_vfmul_vf_f32m2(abs_x, -2.0f, vl); + const vfloat32m2_t exp_term = rvv_expf_approx_f32m2(neg_2_abs, vl); + const vfloat32m2_t numerator = __riscv_vfsub_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t denominator = __riscv_vfadd_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t tanh_abs = __riscv_vfneg_v_f32m2(__riscv_vfdiv_vv_f32m2(numerator, denominator, vl), vl); + const vbool16_t neg_mask = __riscv_vmflt_vf_f32m2_b16(x, 0.0f, vl); + const vfloat32m2_t tanh_neg = __riscv_vfneg_v_f32m2(tanh_abs, vl); + return __riscv_vmerge_vvm_f32m2(tanh_abs, tanh_neg, neg_mask, vl); +} + +static void rvv_softcap_tanh_inplace_f32(float * dst, int64_t dst_stride, int64_t tile_rows, int64_t n, float softcap) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride) { + float * dst_row = dst; + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m2(remaining); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst_row, vl); + v = rvv_tanh_approx_f32m2(v, vl); + v = __riscv_vfmul_vf_f32m2(v, softcap, vl); + __riscv_vse32_v_f32m2(dst_row, v, vl); + dst_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_softmax_exp_inplace_f32(float * dst, int64_t n, float max_value) { + float row_sum = 0.0f; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst, vl); + v = __riscv_vfsub_vf_f32m2(v, max_value, vl); + v = rvv_expf_approx_f32m2(v, vl); + __riscv_vse32_v_f32m2(dst, v, vl); + row_sum += reduce_sum_f32m2_vlen1024(v, vl); + dst += vl; + n -= vl; + } + return row_sum; +} + +static inline float rvv_add_max_inplace_f32(float * dst, const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline float rvv_softcap_add_max_inplace_f32(float * dst, const float * src, int64_t n, float softcap) { + if (softcap == 0.0f) { + return rvv_add_max_inplace_f32(dst, src, n); + } + + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t vdst = __riscv_vle32_v_f32m2(dst, vl); + vfloat32m2_t vsrc = __riscv_vle32_v_f32m2(src, vl); + vdst = rvv_tanh_approx_f32m2(vdst, vl); + vdst = __riscv_vfmul_vf_f32m2(vdst, softcap, vl); + vdst = __riscv_vfadd_vv_f32m2(vdst, vsrc, vl); + __riscv_vse32_v_f32m2(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m2_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline void rvv_zero_f32(float * dst, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t z = __riscv_vfmv_v_f_f32m4(0.0f, vl); + __riscv_vse32_v_f32m4(dst, z, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_scale_f32(float * dst, float scale, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t v = __riscv_vle32_v_f32m4(dst, vl); + v = __riscv_vfmul_vf_f32m4(v, scale, vl); + __riscv_vse32_v_f32m4(dst, v, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_add_inplace_f32(float * dst, + int64_t dst_stride, + const float * src, + int64_t src_stride, + int64_t tile_rows, + int64_t n) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride, src += src_stride) { + int64_t remaining = n; + float * dst_row = dst; + const float * src_row = src; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst_row, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src_row, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst_row, vdst, vl); + dst_row += vl; + src_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_max_f32(const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t v = __riscv_vle32_v_f32m4(src, vl); + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(v, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + src += vl; + n -= vl; + } + return max_val; +} + +static void rvv_pack_f32_as_scaled_f16(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + _Float16 * dst_row_ptr = (_Float16 *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + const vfloat16m2_t v16 = __riscv_vfncvt_f_f_w_f16m2(v32, vl); + __riscv_vse16_v_f16m2(dst_row_ptr, v16, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f16_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const _Float16 * row_ptr = (const _Float16 *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e16m2(remaining); + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(row_ptr, vl); + vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f32_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float * scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale[tq], vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static inline void rvv_transposed_s32_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e32, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vle32.v v1, (s2) \n\t" + "sh2add s2, t0, s2 \n\t" + "vle32.v v2, (s3) \n\t" + "sh2add s3, t0, s3 \n\t" + "vle32.v v3, (s4) \n\t" + "sh2add s4, t0, s4 \n\t" + "vle32.v v4, (s5) \n\t" + "sh2add s5, t0, s5 \n\t" + "vle32.v v5, (s6) \n\t" + "sh2add s6, t0, s6 \n\t" + "vle32.v v6, (s7) \n\t" + "sh2add s7, t0, s7 \n\t" + "vle32.v v7, (s8) \n\t" + "sh2add s8, t0, s8 \n\t" + "vssseg8e32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 32 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vsse32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 4 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_transposed_s16_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e16, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vle16.v v1, (s2) \n\t" + "sh1add s2, t0, s2 \n\t" + "vle16.v v2, (s3) \n\t" + "sh1add s3, t0, s3 \n\t" + "vle16.v v3, (s4) \n\t" + "sh1add s4, t0, s4 \n\t" + "vle16.v v4, (s5) \n\t" + "sh1add s5, t0, s5 \n\t" + "vle16.v v5, (s6) \n\t" + "sh1add s6, t0, s6 \n\t" + "vle16.v v6, (s7) \n\t" + "sh1add s7, t0, s7 \n\t" + "vle16.v v7, (s8) \n\t" + "sh1add s8, t0, s8 \n\t" + "vssseg8e16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 16 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vsse16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 2 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_qk_dot_tile_f16_x1(float * dst, + const _Float16 * q_row, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc = __riscv_vfwmacc_vf_f32m2(acc, q_row[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst, acc, vl); +} + +static inline void rvv_qk_dot_tile_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const _Float16 * q0, + const _Float16 * q1, + const _Float16 * q2, + const _Float16 * q3, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc0 = __riscv_vfwmacc_vf_f32m2(acc0, q0[d], k_vec, vl); + acc1 = __riscv_vfwmacc_vf_f32m2(acc1, q1[d], k_vec, vl); + acc2 = __riscv_vfwmacc_vf_f32m2(acc2, q2[d], k_vec, vl); + acc3 = __riscv_vfwmacc_vf_f32m2(acc3, q3[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst0, acc0, vl); + __riscv_vse32_v_f32m2(dst1, acc1, vl); + __riscv_vse32_v_f32m2(dst2, acc2, vl); + __riscv_vse32_v_f32m2(dst3, acc3, vl); +} + +static inline void rvv_pv_accumulate_f16_x1(float * dst, + const float * prob, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_pv_accumulate_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const float * prob0, + const float * prob1, + const float * prob2, + const float * prob3, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc0 = __riscv_vle32_v_f32m4(dst0 + d_off, vl); + vfloat32m4_t acc1 = __riscv_vle32_v_f32m4(dst1 + d_off, vl); + vfloat32m4_t acc2 = __riscv_vle32_v_f32m4(dst2 + d_off, vl); + vfloat32m4_t acc3 = __riscv_vle32_v_f32m4(dst3 + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc0 = __riscv_vfmacc_vf_f32m4(acc0, prob0[tk], v32, vl); + acc1 = __riscv_vfmacc_vf_f32m4(acc1, prob1[tk], v32, vl); + acc2 = __riscv_vfmacc_vf_f32m4(acc2, prob2[tk], v32, vl); + acc3 = __riscv_vfmacc_vf_f32m4(acc3, prob3[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst0 + d_off, acc0, vl); + __riscv_vse32_v_f32m4(dst1 + d_off, acc1, vl); + __riscv_vse32_v_f32m4(dst2 + d_off, acc2, vl); + __riscv_vse32_v_f32m4(dst3 + d_off, acc3, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_qk_dot_tile(float * dst, + const float * q_row, + const float * k_pack, + int64_t dk, + int64_t kv_tile, + float scale) { + const size_t vl = __riscv_vsetvl_e32m4(kv_tile); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat32m4_t k_vec = __riscv_vle32_v_f32m4(k_pack + d * kv_tile, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, q_row[d] * scale, k_vec, vl); + } + + __riscv_vse32_v_f32m4(dst, acc, vl); +} + +static inline void rvv_pv_accumulate(float * dst, + const float * prob, + const float * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e32m4(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat32m4_t v_vec = __riscv_vle32_v_f32m4(v_pack + tk * dv + d_off, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v_vec, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static void permute_transpose_impl(const ggml_tensor * src0, + ggml_tensor * dst, + int64_t batch, + int64_t m, + int64_t n, + int64_t batch_stride, + int64_t m_src_stride, + int64_t n_src_stride, + int64_t n_dst_stride, + int ith, + int nth) { + GGML_ASSERT(n_src_stride == sizeof(int32_t) || n_src_stride == sizeof(int16_t)); + + if (n_src_stride == sizeof(int32_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else if (n_src_stride == sizeof(int16_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else { + GGML_ABORT("not implemented"); + } +} + +template +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow(float ** pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + float ** sinks, + float ** dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK, + void * tcm_buffer, + size_t tcm_buffer_size) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + float S[QLEN] = { 0.0f }; // sum + float M[QLEN] = { -INFINITY }; // maximum KQ value + + _Float16 * kq16_buffer = (_Float16 *) tcm_buffer; + _Float16 * qv_buffer = kq16_buffer + QLEN * DV; + const size_t qkv_temp_buffer_size = (QLEN * DV + QLEN * DK) * sizeof(_Float16); + char * kv_tile_buffer = (char *) (qv_buffer + QLEN * DK); + + { + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + for (int64_t i = 0; i < QLEN; ++i) { + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq[i], DK), DK); + __riscv_vse16_v_f16m2(qv_buffer + i * DK, Q_q_v, DK); + } + } + + const uintptr_t scratch_addr = reinterpret_cast(kv_tile_buffer); + const size_t scratch_size = tcm_buffer_size > qkv_temp_buffer_size ? tcm_buffer_size - qkv_temp_buffer_size : 0; + const uintptr_t kq_tile_addr = align_up(scratch_addr, alignof(float)); + const size_t scratch_prefix = kq_tile_addr - scratch_addr; + const size_t packed_tile_size = + QLEN * sizeof(float) + DK * sizeof(_Float16) + DV * sizeof(_Float16) + sizeof(float); + const int64_t max_ic_tile_step = ((int64_t) __riscv_vsetvlmax_e16m1()) & ~((int64_t) 7); + const int64_t max_fit_by_tcm = + scratch_size > scratch_prefix ? (int64_t) ((scratch_size - scratch_prefix) / packed_tile_size) : 0; + const int64_t ic_tile_step = std::min(max_ic_tile_step, max_fit_by_tcm) & ~((int64_t) 7); + + const uintptr_t k_tile_addr = kq_tile_addr + QLEN * ic_tile_step * sizeof(float); + const uintptr_t v_tile_addr = k_tile_addr + DK * ic_tile_step * sizeof(_Float16); + const uintptr_t mv_tile_addr = v_tile_addr + ic_tile_step * DV * sizeof(_Float16); + + if (ic_tile_step >= 8) { + float * kq_tile_buffer = reinterpret_cast(kq_tile_addr); + _Float16 * k_tile_pack = reinterpret_cast<_Float16 *>(k_tile_addr); + _Float16 * v_tile_pack = reinterpret_cast<_Float16 *>(v_tile_addr); + float * mv_tile_pack = reinterpret_cast(mv_tile_addr); + + const int64_t k_tile_byte_stride = ic_tile_step * (int64_t) sizeof(_Float16); + + int64_t ic_step = 0; + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + if (mv != -INFINITY) { + const _Float16 * k_data = (const _Float16 *) (k_data_row + ic * nbk1); + const _Float16 * v_data = (const _Float16 *) (v_data_row + ic * nbv1); + + const vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2(k_data, DK); + const vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2(v_data, DV); + __riscv_vsse16_v_f16m2(k_tile_pack + ic_step, k_tile_byte_stride, k_data_v, DK); + __riscv_vse16_v_f16m2(v_tile_pack + ic_step * DV, v_data_v, DV); + mv_tile_pack[ic_step] = mv; + ic_step++; + } + + if (ic_step > 0 && (ic_step == ic_tile_step || ic == (nek1 - 1))) { + if constexpr (QLEN == 4) { + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc2 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc3 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + qk_acc2 = __riscv_vfwmacc_vf_f32m2(qk_acc2, qv_buffer[2 * DK + d], k_vec, qk_vl); + qk_acc3 = __riscv_vfwmacc_vf_f32m2(qk_acc3, qv_buffer[3 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + qk_acc2 = __riscv_vfmul_vf_f32m2(qk_acc2, scale, qk_vl); + qk_acc3 = __riscv_vfmul_vf_f32m2(qk_acc3, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 2 * ic_tile_step, qk_acc2, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 3 * ic_tile_step, qk_acc3, qk_vl); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + } + + for (int i = 0; i < QLEN; ++i) { + float * row_ptr = kq_tile_buffer + i * ic_tile_step; + const float tile_max = + rvv_softcap_add_max_inplace_f32(row_ptr, mv_tile_pack, ic_step, logit_softcap); + + const float Mold = M[i]; + + if (tile_max > Mold) { + const float ms = expf(Mold - tile_max); + M[i] = tile_max; + S[i] *= ms; + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, (_Float16) ms, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + } + + S[i] += rvv_softmax_exp_inplace_f32(row_ptr, ic_step, M[i]); + } + + if constexpr (QLEN == 4) { + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + vfloat16m2_t pv_acc2 = __riscv_vle16_v_f16m2(kq16_buffer + 2 * DV, DV); + vfloat16m2_t pv_acc3 = __riscv_vle16_v_f16m2(kq16_buffer + 3 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + pv_acc2 = + __riscv_vfmacc_vf_f16m2(pv_acc2, (_Float16) kq_tile_buffer[2 * ic_tile_step + tk], v16, DV); + pv_acc3 = + __riscv_vfmacc_vf_f16m2(pv_acc3, (_Float16) kq_tile_buffer[3 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 2 * DV, pv_acc2, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 3 * DV, pv_acc3, DV); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + } + + ic_step = 0; + } + } + } else { + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + const char * k_data = k_data_row + ic * nbk1; + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t k_data_v; + vfloat16m2_t v_data_v; + + if (mv != -INFINITY) { + k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + } else { + continue; + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t Q_q_v = __riscv_vle16_v_f16m2(qv_buffer + i * DK, DK); + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + s = s * scale; + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + s += mv; + + const float Mold = M[i]; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + if (s > M[i]) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M[i] = s; + ms = expf(Mold - M[i]); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M[i]); + } + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + S[i] = S[i] * ms + vs; // scale and increment sum with partial sum + } + } + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks[i]) { + const float s = *(sinks[i]); + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[i]) { + ms = expf(M[i] - s); + M[i] = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M[i]); + } + + S[i] = S[i] * ms + vs; + } + + // V /= S + const float S_inv = S[i] == 0.0f ? 0.0f : 1.0f / S[i]; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst[i], VKQ32_v, DV); + } +} + +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1(const float * pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + const float * sinks, + float * dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq, DK), DK); + + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + if (mv == -INFINITY) { + continue; + } + + const char * k_data = k_data_row + ic * nbk1; + + vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + + s = s * scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + + S = S * ms + vs; // scale and increment sum with partial sum + } + + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks) { + const float s = *sinks; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + M = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f / S; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst, VKQ32_v, DV); +} + +} // namespace + +void memcpy1d(void * dst, const void * src, int64_t size) { + size_t byte_size_all = size; + size_t vlen = __riscv_vlenb() * 8; + if (vlen == 256) { + // 1024 bytes + __asm__ volatile( + // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1"); + } else if (vlen == 1024) { + // 2048 bytes + __asm__ volatile( + // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "addi t2, %[s], 1024 \n\t" + "addi t3, %[d], 1024 \n\t" + "li t5, 2048 \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t5 \n\t" + "vle8.v v8, (t2) \n\t" + "add t2, t2, t5 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t5 \n\t" + "vse8.v v8, (t3) \n\t" + "add t3, t3, t5 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m2, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t3", "t5"); + } else { + __asm__ volatile( + // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t4", "t3"); + } +} + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size) { + for (int64_t i = 0; i < tile_rows; ++i) { + memcpy1d((char *) dst + i * dst_stride, (const char *) src + i * src_stride, size); + } +} + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + // parallelize by q rows using ggml_vec_dot_f32 + + float scale = *((float *) dst->op_params + 0); + float max_bias = *((float *) dst->op_params + 1); + float logit_softcap = *((float *) dst->op_params + 2); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int KV_row_size = DK * sizeof(_Float16) + DV * sizeof(_Float16); + + int ith = params->ith; + int ir_step = 1; + for (int ir = ir0; ir < ir1; ir += ir_step) { + // q indices + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + const int iq3_1 = (ir + 1) / (neq2 * neq1); + const int iq2_1 = (ir + 1 - iq3_1 * neq2 * neq1) / neq1; + const int iq1_1 = (ir + 1 - iq3_1 * neq2 * neq1 - iq2_1 * neq1); + + const int iq3_2 = (ir + 2) / (neq2 * neq1); + const int iq2_2 = (ir + 2 - iq3_2 * neq2 * neq1) / neq1; + const int iq1_2 = (ir + 2 - iq3_2 * neq2 * neq1 - iq2_2 * neq1); + + const int iq3_3 = (ir + 3) / (neq2 * neq1); + const int iq2_3 = (ir + 3 - iq3_3 * neq2 * neq1) / neq1; + const int iq1_3 = (ir + 3 - iq3_3 * neq2 * neq1 - iq2_3 * neq1); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + const ggml_fp16_t * mp = + mask ? (ggml_fp16_t *) ((char *) mask->data + iq1 * mask->nb[1] + (iq2 % mask->ne[2]) * mask->nb[2] + + (iq3 % mask->ne[3]) * mask->nb[3]) : + NULL; + + const bool mp_equal_2 = iq1_1 == iq1 && (iq2 % mask->ne[2]) == (iq2_1 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_1 % mask->ne[3]); + + const bool mp_equal_4 = mp_equal_2 && iq1_2 == iq1 && (iq2 % mask->ne[2]) == (iq2_2 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_2 % mask->ne[3]) && iq1_3 == iq1 && + (iq2 % mask->ne[2]) == (iq2_3 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_3 % mask->ne[3]); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + const int ik3_1 = iq3_1 / rk3; + const int ik2_1 = iq2_1 / rk2; + + const int ik3_2 = iq3_2 / rk3; + const int ik2_2 = iq2_2 / rk2; + + const int ik3_3 = iq3_3 / rk3; + const int ik2_3 = iq2_3 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const int iv3_1 = iq3_1 / rv3; + const int iv2_1 = iq2_1 / rv2; + + const int iv3_2 = iq3_2 / rv3; + const int iv2_2 = iq2_2 / rv2; + + const int iv3_3 = iq3_3 / rv3; + const int iv2_3 = iq2_3 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + + std::array pq_buffer; + std::array sinks_buffer; + std::array dst_buffer; + + if (tcm_buffer != nullptr && 4 * KV_row_size < tcm_buffer_size && ir < (ir1 - 3) && mp_equal_4 && + ik3_3 == ik3 && ik2_3 == ik2 && iv3_3 == iv3 && iv2_3 == iv2 && ik3_2 == ik3 && ik2_2 == ik2 && + iv3_2 == iv3 && iv2_2 == iv2 && ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 4; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + pq_buffer[2] = (float *) ((char *) q->data + (iq1_2 * nbq1 + iq2_2 * nbq2 + iq3_2 * nbq3)); + pq_buffer[3] = (float *) ((char *) q->data + (iq1_3 * nbq1 + iq2_3 * nbq2 + iq3_3 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + sinks_buffer[2] = sinks ? ((float *) ((char *) sinks->data)) + iq2_2 : nullptr; + sinks_buffer[3] = sinks ? ((float *) ((char *) sinks->data)) + iq2_3 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + dst_buffer[2] = (float *) ((char *) dst->data + (iq3_2 * ne2 * ne1 + iq2_2 + iq1_2 * ne1) * nb1); + dst_buffer[3] = (float *) ((char *) dst->data + (iq3_3 * ne2 * ne1 + iq2_3 + iq1_3 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<4>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else if (tcm_buffer != nullptr && 2 * KV_row_size < tcm_buffer_size && ir < (ir1 - 1) && mp_equal_2 && + ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 2; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<2>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else { + ir_step = 1; + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1( // + pq, // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks ? ((float *) ((char *) sinks->data)) + h : nullptr, // + (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK); + } + } +} + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + float * param_list = (float *) dst->op_params; + float scale = param_list[0]; + float max_bias = param_list[1]; + float logit_softcap = param_list[2]; + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + // Per-thread scratch layout: + // Q_f32: Q_TILE_SZ * DK + // KQ: Q_TILE_SZ * KV_TILE_SZ + // mask32: Q_TILE_SZ * KV_TILE_SZ + // VKQ32: Q_TILE_SZ * DV + // V32: KV_TILE_SZ * DV + // K_f32: DK * KV_TILE_SZ (transposed K tile) + float * base = (float *) params->wdata + ith * (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + + KV_TILE_SZ * DV + KV_TILE_SZ * DK + CACHE_LINE_SIZE_F32); + const size_t base_size = + (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + KV_TILE_SZ * DV + KV_TILE_SZ * DK) * + sizeof(float) + + CACHE_LINE_SIZE_F32; + + if (base_size <= tcm_buffer_size && tcm_buffer != nullptr) { + base = (float *) tcm_buffer; + } + + float S_M_Buf[Q_TILE_SZ * 2]; // buffer to hold S, M, bias for one tile to reduce register pressure in main loop + float * S = S_M_Buf; + float * M = S_M_Buf + Q_TILE_SZ; + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int) (ir1 - ir), (int) (neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + for (int i = 0; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + float * Q_f32 = base; + float * KQ = (float *) ((char *) base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; + _Float16 * Q_f16 = (_Float16 *) Q_f32; + _Float16 * V_f16 = (_Float16 *) V32; + _Float16 * K_f16 = (_Float16 *) K_f32; + + rvv_zero_f32(VKQ32, Q_TILE_SZ * DV); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + if (kv_type == GGML_TYPE_F16) { + rvv_pack_f32_as_scaled_f16((uint8_t *) Q_f16, DK * sizeof(_Float16), (uint8_t *) pq, nbq1, tile_rows, DK, + scale); + } else { + memcpy2d(Q_f32, DK * sizeof(float), pq, nbq1, tile_rows, DK * sizeof(float)); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int) std::min((int64_t) KV_TILE_SZ, nek1 - ic); + + rvv_zero_f32(K_f32, DK * KV_TILE_SZ); + rvv_zero_f32(V32, KV_TILE_SZ * DV); + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + const ggml_fp16_t * mp_row = + (const ggml_fp16_t *) ((const char *) mask->data + iq1 * mask->nb[1] + + (iq2 % mask->ne[2]) * mask->nb[2] + (iq3 % mask->ne[3]) * mask->nb[3]); + rvv_pack_scaled_f16_as_f32(mask32, KV_TILE_SZ * sizeof(float), mp_row + ic, mask->nb[1], tile_rows, + kv_tile, slope); + + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = 0; tk < kv_tile; tk++) { + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + + if (can_skip) { + continue; + } + } + + if (kv_type == GGML_TYPE_F16) { + rvv_transposed_s16_mn_to_nm((int8_t *) K_f16, KV_TILE_SZ * sizeof(_Float16), + (int8_t *) k->data + ic * nbk1 + ik2 * nbk2 + ik3 * nbk3, nbk1, kv_tile, + DK); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + rvv_qk_dot_tile_f16_x4(KQ + (tq + 0) * KV_TILE_SZ, KQ + (tq + 1) * KV_TILE_SZ, + KQ + (tq + 2) * KV_TILE_SZ, KQ + (tq + 3) * KV_TILE_SZ, + Q_f16 + (tq + 0) * DK, Q_f16 + (tq + 1) * DK, Q_f16 + (tq + 2) * DK, + Q_f16 + (tq + 3) * DK, K_f16, DK, kv_tile); + } + for (; tq < tile_rows; ++tq) { + rvv_qk_dot_tile_f16_x1(KQ + tq * KV_TILE_SZ, Q_f16 + tq * DK, K_f16, DK, kv_tile); + } + } else { + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *) k->data + (ic + tk) * nbk1 + ik2 * nbk2 + ik3 * nbk3; + float * k_col = K_f32 + tk; + const float * k_src = (const float *) k_data; + for (int64_t dk = 0; dk < DK; ++dk) { + k_col[dk * KV_TILE_SZ] = k_src[dk]; + } + } + + for (int tq = 0; tq < tile_rows; ++tq) { + rvv_qk_dot_tile(KQ + tq * KV_TILE_SZ, Q_f32 + tq * DK, K_f32, DK, KV_TILE_SZ, scale); + } + } + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + } + + if (logit_softcap != 0.0f) { + rvv_softcap_tanh_inplace_f32(KQ, KV_TILE_SZ, tile_rows, KV_TILE_SZ, logit_softcap); + } + + if (mask) { + rvv_add_inplace_f32(KQ, KV_TILE_SZ, mask32, KV_TILE_SZ, tile_rows, KV_TILE_SZ); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < tile_rows; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + const float tile_max = rvv_max_f32(kq_row, KV_TILE_SZ); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + S[tq] *= ms; + } + M[tq] = Mnew; + + S[tq] += rvv_softmax_exp_inplace_f32(kq_row, KV_TILE_SZ, Mnew); + } + + // Pack V as contiguous [KV_TILE_SZ][DV]. + if (kv_type == GGML_TYPE_F16) { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V_f16, DV * sizeof(_Float16), v_data, nbv1, kv_tile, DV * sizeof(_Float16)); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + if (skip[tq + 0] || skip[tq + 1] || skip[tq + 2] || skip[tq + 3]) { + for (int i = 0; i < 4; ++i) { + if (!skip[tq + i]) { + rvv_pv_accumulate_f16_x1(VKQ32 + (tq + i) * DV, KQ + (tq + i) * KV_TILE_SZ, V_f16, + KV_TILE_SZ, DV); + } + } + continue; + } + + rvv_pv_accumulate_f16_x4(VKQ32 + (tq + 0) * DV, VKQ32 + (tq + 1) * DV, VKQ32 + (tq + 2) * DV, + VKQ32 + (tq + 3) * DV, KQ + (tq + 0) * KV_TILE_SZ, + KQ + (tq + 1) * KV_TILE_SZ, KQ + (tq + 2) * KV_TILE_SZ, + KQ + (tq + 3) * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + for (; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate_f16_x1(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + } + } else { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V32, DV * sizeof(float), v_data, nbv1, kv_tile, DV * sizeof(float)); + + for (int tq = 0; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V32, KV_TILE_SZ, DV); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *) ((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + } else { + vs = expf(s - M[tq]); + } + + float S_temp = S[tq] * ms + vs; + S[tq] = S_temp == 0.0f ? 0.0f : 1.0f / S_temp; + } + } else { + for (int tq = 0; tq < tile_rows; tq++) { + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + S[tq] = S_inv; + } + } + + float * dst_ptr = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + (iq1) *ne1) * nb1); + rvv_pack_scaled_f32_as_f32(dst_ptr, nb1 * ne1, VKQ32, DV * sizeof(float), tile_rows, DV, S); + + ir += tile_rows; + } +} + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template +void quantize_a_nrow_i8_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk])); + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + int16_t a_sum = 0; + for (size_t bk = 0; bk < blk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row] = -a_sum; + } + } +} + +template +void quantize_a_nrow_i8_hp_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + const size_t subblk_count = blk_len / k_subblk_len; + + GGML_ASSERT(blk_len == 256); + + float scale_temp[8] = { 0.0f }; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * MB_ROWS; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + + float scale_avg = 0.0f; + for (size_t kk = 0; kk < subblk_count; kk++) { + float max_abs_a = 0.0f; + for (size_t row = 0; row < MB_ROWS; row++) { + for (size_t bk = 0; bk < k_subblk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk + kk * k_subblk_len])); + } + } + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + float scale_factor = 1.0f / scale_avg; + + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + scale_avg_ptr[0] = scale_avg; + + for (size_t kk = 0; kk < subblk_count; kk++) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * MB_ROWS); + + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + const float rep_scale_a = 1.0f / scale_temp[kk]; + + for (size_t row = 0; row < MB_ROWS; row++) { + int16_t a_sum = 0; + for (size_t bk = 0; bk < k_subblk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk + kk * k_subblk_len] * rep_scale_a), + -128.0f, 127.0f)); + quant_a_blk[row * k_subblk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * subblk_count + kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + } + } + } +} + +template +void quantize_a_nrow_i8k_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_sum_size = 256 / 16; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * a_sum_size * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_a = 0.0f; + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + float ax = std::abs(a_ptr[row * count_k + k + bk]); + if (ax > max_abs_a) { + max_abs_a = ax; + max_a = a_ptr[row * count_k + k + bk]; + } + } + + if (!max_abs_a) { + scale_a_ptr[row] = 0; + for (size_t bki = 0; bki < a_sum_size; bki++) { + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + quant_a_blk[row * blk_len + bk] = 0; + } + a_sum_ptr[row * a_sum_size + bki] = 0; + } + continue; + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + for (size_t bki = 0; bki < a_sum_size; bki++) { + int16_t a_sum = 0; + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * a_sum_size + bki] = -a_sum; + } + } + } +} + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } else { + quantize_a_nrow_i8_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + mi * count_k + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false); + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * 4; + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + vfloat32m1_t v_a0_abs = __riscv_vfabs_v_f32m1(v_a0, vl); + vfloat32m1_t v_a1_abs = __riscv_vfabs_v_f32m1(v_a1, vl); + vfloat32m1_t v_a2_abs = __riscv_vfabs_v_f32m1(v_a2, vl); + vfloat32m1_t v_a3_abs = __riscv_vfabs_v_f32m1(v_a3, vl); + + vfloat32m1_t v_max_abs = __riscv_vfmax_vv_f32m1(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a0_scale = __riscv_vfmul_vf_f32m1(v_a0, rep_scale_a, vl); + vfloat32m1_t v_a1_scale = __riscv_vfmul_vf_f32m1(v_a1, rep_scale_a, vl); + vfloat32m1_t v_a2_scale = __riscv_vfmul_vf_f32m1(v_a2, rep_scale_a, vl); + vfloat32m1_t v_a3_scale = __riscv_vfmul_vf_f32m1(v_a3, rep_scale_a, vl); + vint16mf2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a0_scale, vl); + vint16mf2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a1_scale, vl); + vint16mf2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a2_scale, vl); + vint16mf2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a3_scale, vl); + vint8mf4_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a0_quant, vl); + vint8mf4_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a1_quant, vl); + vint8mf4_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a2_quant, vl); + vint8mf4_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + vfloat32m4_t v_a0_abs = __riscv_vfabs_v_f32m4(v_a0, vl); + vfloat32m4_t v_a1_abs = __riscv_vfabs_v_f32m4(v_a1, vl); + vfloat32m4_t v_a2_abs = __riscv_vfabs_v_f32m4(v_a2, vl); + vfloat32m4_t v_a3_abs = __riscv_vfabs_v_f32m4(v_a3, vl); + + vfloat32m4_t v_max_abs = __riscv_vfmax_vv_f32m4(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a0_scale = __riscv_vfmul_vf_f32m4(v_a0, rep_scale_a, vl); + vfloat32m4_t v_a1_scale = __riscv_vfmul_vf_f32m4(v_a1, rep_scale_a, vl); + vfloat32m4_t v_a2_scale = __riscv_vfmul_vf_f32m4(v_a2, rep_scale_a, vl); + vfloat32m4_t v_a3_scale = __riscv_vfmul_vf_f32m4(v_a3, rep_scale_a, vl); + vint16m2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16m2(v_a0_scale, vl); + vint16m2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16m2(v_a1_scale, vl); + vint16m2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16m2(v_a2_scale, vl); + vint16m2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16m2(v_a3_scale, vl); + vint8m1_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a0_quant, vl); + vint8m1_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a1_quant, vl); + vint8m1_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a2_quant, vl); + vint8m1_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits, can process 32 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + // vlen = 256 bits, can process 8 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8k_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_nrow_block_stride = a_blk_stride * 4; + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else if (vlenb == 32) { + // vlen = 256 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else { + quantize_a_nrow_i8k_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = src0->ne[2] * src0->ne[3]; + int64_t m = src0->ne[1]; + int64_t n = src0->ne[0]; + + int64_t batch_stride = src0->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = n_src_stride * m; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = dst->ne[2] * dst->ne[3]; + int64_t n = dst->ne[1]; + int64_t m = dst->ne[0]; + + int64_t batch_stride = dst->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = dst->nb[1]; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + auto src0_rows = ggml_nrows(src0); + auto src1_rows = ggml_nrows(src1); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(T)); + GGML_ASSERT(nb00 == sizeof(T)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + auto compute_func_vv = [&](int64_t blk_len, int64_t r, T * src0_ptr, T * src1_ptr, T * dst_ptr) { + int64_t idx = 0; + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfadd_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfadd_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfsub_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfsub_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfmul_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfmul_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfdiv_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfdiv_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + }; + + if (src0_rows == src1_rows && src0_rows == 1 && ne00 == ne10) { + int64_t task_per_thread = (ne00 + nth - 1) / nth; + int64_t task_begin = ith * task_per_thread; + int64_t task_end = std::min((ith + 1) * task_per_thread, ne00); + + T * dst_ptr = ((T *) dst->data) + task_begin; + T * src0_ptr = ((T *) src0->data) + task_begin; + T * src1_ptr = ((T *) src1->data) + task_begin; + + compute_func_vv(task_end - task_begin, 0, src0_ptr, src1_ptr, dst_ptr); + } else if (ne10 > 1) { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + // src1 is broadcastable across src0 and dst in i1, i2, i3 + for (int64_t r = 0; r < ne00; r += ne10) { + compute_func_vv(ne10, r, src0_ptr, src1_ptr, dst_ptr); + } + } + } else { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + T rhs_scalar = src1_ptr[0]; + int64_t blk_len = ne00; + int64_t r = 0; + + for (size_t vl; blk_len > 0; blk_len -= vl, r += vl) { + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfadd_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfadd_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfsub_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfsub_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfmul_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfmul_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfdiv_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfdiv_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + int64_t n_task = ne01 * ne02 * ne03; + int64_t task_per_thread = (n_task + nth - 1) / nth; + int64_t ir_start = ith * task_per_thread; + int64_t ir_end = std::min(ir_start + task_per_thread, n_task); + + for (int64_t ir = ir_start; ir < ir_end; ir++) { + const int64_t i3 = ir / (ne02 * ne01); + const int64_t i2 = (ir - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (ir - i3 * ne02 * ne01 - i2 * ne01); + + T * src_row = (T *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + T * dst_row = (T *) ((char *) op->data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + float row_sum = 0; + + if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const float * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t vec = __riscv_vle32_v_f32m4(p_data, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e16m2(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const _Float16 * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e16m2(length); + vfloat16m2_t vec_f16 = __riscv_vle16_v_f16m2(p_data, vl); + vfloat32m4_t vec_f32 = __riscv_vfwcvt_f_f_v_f32m4(vec_f16, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec_f32, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else { + GGML_ABORT("fatal error"); + } + + dst_row[0] = row_sum; + } +} + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t nrows = ggml_nrows(src0); + int64_t nrows_per_thread = (nrows + nth - 1) / nth; + int64_t ir_start = ith * nrows_per_thread; + int64_t ir_end = std::min(ir_start + nrows_per_thread, nrows); + + if (src0->ne[0] == 1) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + T src_scalar = src_row[0]; + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vmv_v_x_i32m4(src_scalar, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vmv_v_x_i16m4(src_scalar, vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else if (src0->ne[0] == dst->ne[0]) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_row + idx, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_row + idx), vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else { + GGML_ABORT("fatal error"); + } +} + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t total_batches = ne2 * ne3; + const int64_t batches_per_thread = (total_batches + nth - 1) / nth; + const int64_t batch_start = ith * batches_per_thread; + const int64_t batch_end = std::min(batch_start + batches_per_thread, total_batches); + + for (int64_t b = batch_start; b < batch_end; b++) { + const int64_t i3 = b / ne2; + const int64_t i2 = b % ne2; + + T * src_base = (T *) ((char *) src0->data + i2 * src0->nb[2] + i3 * src0->nb[3]); + T * dst_batch = (T *) ((char *) dst->data + i2 * dst->nb[2] + i3 * dst->nb[3]); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + T * dst_ptr = (T *) ((char *) dst_batch + i1 * dst->nb[1]); + int64_t length = ne0; + int64_t idx = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_base + idx, vl); + __riscv_vse32_v_i32m4(dst_ptr + idx, vec, vl); + idx += vl; + length -= vl; + } else if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_base + idx), vl); + __riscv_vse16_v_i16m4((dst_ptr + idx), vec, vl); + idx += vl; + length -= vl; + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(op) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + // rows per thread + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i / (ne11 * ne10); + const int64_t i11 = (i - i12 * ne11 * ne10) / ne10; + const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + memcpy1d(((char *) dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3) + cr0 * sizeof(T), + ((char *) src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03) + cr0 * sizeof(T), + (cr1 - cr0) * sizeof(T)); + } +} + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim == 0 && nb0 == sizeof(float) && nb1 == sizeof(float) * (ne00 + ne10)); + + const int64_t nr = ggml_nrows(dst); + const int64_t nc = ne0; + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + int64_t o[4] = { 0, 0, 0, 0 }; + o[dim] = src0->ne[dim]; + const float * x; + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i3 = i / (ne02 * ne01); + const int64_t i2 = (i - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (i - i3 * ne02 * ne01 - i2 * ne01); + + for (int i0 = cr0; i0 < cr1; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *) src0->data + (i0) *nb00 + (i1) *nb01 + (i2) *nb02 + (i3) *nb03); + } else { + x = (const float *) ((const char *) src1->data + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); + } + + float * y = (float *) ((char *) dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + + *y = *x; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows<_Float16>(const ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +} // namespace spacemit_kernels::rvv diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.h b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h new file mode 100644 index 00000000000..edddf957c21 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h @@ -0,0 +1,95 @@ +#pragma once + +#include "ggml-cpu-impl.h" + +#include +#include +#include +#include + +namespace spacemit_kernels { + +constexpr auto div_round_up(auto up, auto down) { + return (up + down - 1) / down; +} + +// Q8 Blk [f32] [s16] [int8 * blk_len] +// Q8 Blk N [f32 * N] [s16 * N] [int8 * blk_len * N] +constexpr size_t q8_blk_size(size_t blk_len, bool with_blk_sum = false) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + (with_blk_sum ? sizeof(int16_t) : 0); + return blk_size; +} + +// Q8 HP row block: K is split into K32 subblocks. +// Each subblock stores [f32 scale] [int8 * 32], with an optional fp16 sum trailer per subblock. +constexpr size_t q8_hp_blk_size(size_t blk_len, bool with_blk_sum = false, bool with_blk_scale = false) { + const size_t subblk_count = div_round_up(blk_len, size_t(32)); + const size_t blk_size = blk_len * sizeof(int8_t) + subblk_count * sizeof(_Float16) + + (with_blk_sum ? subblk_count * sizeof(_Float16) : 0) + + (with_blk_scale ? sizeof(_Float16) : 0); + return blk_size; +} + +// Q8K Blk [f32] [s16 * (blk_len / 16)] [int8 * blk_len] +// Q8K Blk N [f32 * N] [s16 * (blk_len / 16) * N] [int8 * blk_len * N] +constexpr size_t q8k_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + sizeof(int16_t) * blk_len / 16; + return blk_size; +} + +using quantize_a_row_def = std::function; + +namespace rvv { +void memcpy1d(void * dst, const void * src, int64_t size); + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size); + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op); + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op); + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +} // namespace rvv + +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/spine_barrier.h b/ggml/src/ggml-cpu/spacemit/spine_barrier.h new file mode 100644 index 00000000000..f897dad4b8a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_barrier.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#define SPINE_CACHE_LINE 64 +#define SPINE_CACHE_ALIGN __attribute__((aligned(SPINE_CACHE_LINE))) + +struct spine_barrier_t { + SPINE_CACHE_ALIGN std::atomic pending_; + SPINE_CACHE_ALIGN std::atomic rounds_; + SPINE_CACHE_ALIGN int64_t total_; +}; + +inline void spine_barrier_wait(spine_barrier_t * b) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + while (cur_round == b->rounds_.load(std::memory_order_relaxed)) { + __asm__ volatile("pause " ::: "memory"); + } + } +} + +inline void spine_barrier_init(spine_barrier_t * b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp new file mode 100644 index 00000000000..1409423b145 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp @@ -0,0 +1,760 @@ +#include "spine_mem_pool.h" + +#include "common.h" +#include "ime_env.h" +#include "spine_tcm.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +namespace { + +constexpr size_t SPINE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull * 1024ull; +constexpr size_t SPINE_SHARE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull; +constexpr size_t SPINE_MEM_POOL_1G_REGION_SIZE = 1ull << 30; +constexpr uint64_t HUGETLB_1G_FLAG_REQUIRE_PUD = 1ull << 0; +constexpr char SPINE_MEM_POOL_HUGETLB_1G_DEV[] = "/dev/hugetlb_1g"; +constexpr char SPINE_MEM_POOL_TCM_SYNC_MEM_DEV[] = "/dev/tcm_sync_mem"; + +struct hugetlb_1g_region { + uint64_t size{ 0 }; + uint64_t dma_addr{ 0 }; + uint64_t flags{ 0 }; + uint64_t reserved{ 0 }; +}; + +#define HUGETLB_1G_IOC_MAGIC 'M' +#define HUGETLB_1G_IOC_ALLOC _IOWR(HUGETLB_1G_IOC_MAGIC, 0x00, struct hugetlb_1g_region) +#define HUGETLB_1G_IOC_FREE _IO(HUGETLB_1G_IOC_MAGIC, 0x01) + +struct free_block { + size_t offset{ 0 }; + size_t size{ 0 }; +}; + +struct pool_chunk { + uint8_t * base{ nullptr }; + size_t size{ 0 }; + int fd{ -1 }; + std::vector free_blocks; +}; + +struct pool_allocation { + void * chunk_base{ nullptr }; + size_t chunk_size{ 0 }; + void * base{ nullptr }; + size_t size{ 0 }; +}; + +bool is_power_of_two(size_t value) { + return value != 0 && (value & (value - 1)) == 0; +} + +bool align_up(size_t value, size_t alignment, size_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const size_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const size_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +bool align_up_uintptr(uintptr_t value, size_t alignment, uintptr_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const uintptr_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const uintptr_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +class spine_mem_pool_manager { + public: + explicit spine_mem_pool_manager(size_t default_chunk_size) : default_chunk_size_(default_chunk_size) {} + + virtual ~spine_mem_pool_manager() = default; + + void * alloc(size_t size, size_t alignment) { + if (size == 0 || !is_power_of_two(alignment)) { + return nullptr; + } + + size_t aligned_size = 0; + if (!align_up(size, alignment, &aligned_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: align_up failed for size %zu alignment %zu\n", __func__, size, + alignment); + return nullptr; + } + + pool_allocation allocation; + + std::lock_guard lock(mutex_); + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + if (!add_chunk_locked(aligned_size, alignment)) { + return nullptr; + } + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation retry failed for size %zu alignment %zu\n", + __func__, aligned_size, alignment); + return nullptr; + } + } + + try { + const auto [allocation_it, inserted] = allocations_.emplace(allocation.base, allocation); + if (!inserted) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: duplicate allocation key %p\n", __func__, allocation.base); + rollback_allocation_locked(allocation); + return nullptr; + } + } catch (const std::bad_alloc &) { + rollback_allocation_locked(allocation); + throw; + } + + return allocation.base; + } + + void free(void * base) { + if (base == nullptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto allocation_it = allocations_.find(base); + if (allocation_it == allocations_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown allocation %p\n", __func__, base); + return; + } + + pool_allocation allocation = allocation_it->second; + allocations_.erase(allocation_it); + + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown chunk for allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p out of chunk range %p..%p\n", __func__, + allocation.base, chunk_base, chunk_base + chunk_it->size); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p size %zu exceeds chunk size %zu\n", __func__, + allocation.base, allocation.size, chunk_it->size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + protected: + void release_chunks() { + std::lock_guard lock(mutex_); + + allocations_.clear(); + for (auto & chunk : chunks_) { + dealloc_chunk(&chunk); + } + chunks_.clear(); + } + + size_t default_chunk_size() const { return default_chunk_size_; } + + static void clear_chunk(pool_chunk * chunk) { + chunk->base = nullptr; + chunk->size = 0; + chunk->fd = -1; + chunk->free_blocks.clear(); + } + + virtual bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) = 0; + virtual void dealloc_chunk(pool_chunk * chunk) = 0; + + private: + struct alloc_candidate { + size_t chunk_index{ 0 }; + size_t block_index{ 0 }; + size_t aligned_offset{ 0 }; + uintptr_t address{ std::numeric_limits::max() }; + bool valid{ false }; + }; + + std::vector::iterator find_chunk_locked(const pool_allocation & allocation) { + return std::find_if(chunks_.begin(), chunks_.end(), [&](const pool_chunk & chunk) { + return chunk.base == allocation.chunk_base && chunk.size == allocation.chunk_size; + }); + } + + bool add_chunk_locked(size_t min_size, size_t alignment) { + pool_chunk chunk; + const size_t chunk_request = default_chunk_size_ == 0 ? min_size : std::max(min_size, default_chunk_size_); + void * hint_addr = nullptr; + + for (const auto & existing_chunk : chunks_) { + auto * chunk_end = existing_chunk.base + existing_chunk.size; + if (hint_addr == nullptr || chunk_end > hint_addr) { + hint_addr = chunk_end; + } + } + + if (!alloc_chunk(chunk_request, alignment, hint_addr, &chunk)) { + return false; + } + + if (chunk.base == nullptr || chunk.size < min_size) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: invalid chunk returned for request size %zu, chunk_base=%p chunk_size=%zu\n", + __func__, min_size, chunk.base, chunk.size); + dealloc_chunk(&chunk); + return false; + } + + try { + chunk.free_blocks.push_back({ 0, chunk.size }); + chunks_.push_back(std::move(chunk)); + } catch (const std::bad_alloc &) { + dealloc_chunk(&chunk); + throw; + } + + return true; + } + + void rollback_allocation_locked(const pool_allocation & allocation) { + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, owning chunk not found\n", + __func__, allocation.base); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, chunk range is invalid\n", + __func__, allocation.base); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + bool try_alloc_locked(size_t size, size_t alignment, pool_allocation * allocation) { + alloc_candidate best; + + for (size_t chunk_index = 0; chunk_index < chunks_.size(); ++chunk_index) { + const auto & chunk = chunks_[chunk_index]; + for (size_t block_index = 0; block_index < chunk.free_blocks.size(); ++block_index) { + const auto & block = chunk.free_blocks[block_index]; + + uintptr_t aligned_addr = 0; + const auto block_addr = reinterpret_cast(chunk.base + block.offset); + if (!align_up_uintptr(block_addr, alignment, &aligned_addr)) { + continue; + } + + if (aligned_addr < block_addr) { + continue; + } + + const size_t aligned_offset = block.offset + static_cast(aligned_addr - block_addr); + const size_t padding = aligned_offset - block.offset; + if (padding > block.size || size > block.size - padding) { + continue; + } + + if (!best.valid || aligned_addr < best.address) { + best.chunk_index = chunk_index; + best.block_index = block_index; + best.aligned_offset = aligned_offset; + best.address = aligned_addr; + best.valid = true; + } + } + } + + if (!best.valid) { + return false; + } + + auto & chunk = chunks_[best.chunk_index]; + const free_block block = chunk.free_blocks[best.block_index]; + const size_t padding = best.aligned_offset - block.offset; + const size_t alloc_end = best.aligned_offset + size; + const size_t block_end = block.offset + block.size; + + chunk.free_blocks.erase(chunk.free_blocks.begin() + best.block_index); + auto insert_it = chunk.free_blocks.begin() + best.block_index; + if (padding != 0) { + insert_it = chunk.free_blocks.insert(insert_it, { block.offset, padding }); + ++insert_it; + } + if (alloc_end < block_end) { + chunk.free_blocks.insert(insert_it, { alloc_end, block_end - alloc_end }); + } + + allocation->chunk_base = chunk.base; + allocation->chunk_size = chunk.size; + allocation->base = chunk.base + best.aligned_offset; + allocation->size = size; + return true; + } + + void maybe_release_empty_chunk_locked(std::vector::iterator chunk_it) { + if (chunk_it->free_blocks.size() != 1) { + return; + } + + const auto & block = chunk_it->free_blocks.front(); + if (block.offset != 0 || block.size != chunk_it->size) { + return; + } + + dealloc_chunk(&*chunk_it); + chunks_.erase(chunk_it); + } + + void insert_free_block_locked(pool_chunk & chunk, free_block block) { + auto it = chunk.free_blocks.begin(); + while (it != chunk.free_blocks.end() && it->offset < block.offset) { + ++it; + } + + if (it != chunk.free_blocks.begin()) { + const auto & prev = *(it - 1); + if (prev.offset + prev.size > block.offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + } + + if (it != chunk.free_blocks.end() && block.offset + block.size > it->offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping next free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + + it = chunk.free_blocks.insert(it, block); + + if (it != chunk.free_blocks.begin()) { + auto prev = it - 1; + if (prev->offset + prev->size == it->offset) { + it->offset = prev->offset; + it->size += prev->size; + it = chunk.free_blocks.erase(prev); + } + } + + if (it + 1 != chunk.free_blocks.end() && it->offset + it->size == (it + 1)->offset) { + it->size += (it + 1)->size; + chunk.free_blocks.erase(it + 1); + } + } + + std::mutex mutex_; + std::vector chunks_; + std::unordered_map allocations_; + size_t default_chunk_size_{ 0 }; +}; + +class spine_mem_pool_posix final : public spine_mem_pool_manager { + public: + spine_mem_pool_posix() : spine_mem_pool_manager(0) {} + + ~spine_mem_pool_posix() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) hint_addr; + + const size_t alloc_alignment = std::max(alignment, sizeof(void *)); + void * base = nullptr; + const int rc = posix_memalign(&base, alloc_alignment, min_size); + if (rc != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: posix_memalign failed for size %zu alignment %zu, rc=%d\n", + __func__, min_size, alloc_alignment, rc); + return false; + } + + chunk->base = static_cast(base); + chunk->size = min_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + std::free(chunk->base); + clear_chunk(chunk); + } +}; + +class spine_mem_pool_transparent_hugepage final : public spine_mem_pool_manager { + public: + spine_mem_pool_transparent_hugepage() : spine_mem_pool_manager(SPINE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_transparent_hugepage() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + size_t chunk_size = 0; + if (!align_up(min_size, default_chunk_size(), &chunk_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round chunk size for %zu\n", __func__, min_size); + return false; + } + + void * map_addr = mmap(hint_addr, chunk_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for chunk size %zu, errno=%d\n", __func__, chunk_size, + errno); + return false; + } + + if (madvise(map_addr, chunk_size, MADV_HUGEPAGE) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: madvise(MADV_HUGEPAGE) failed for chunk size %zu, errno=%d\n", + __func__, chunk_size, errno); + munmap(map_addr, chunk_size); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = chunk_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for chunk %p size %zu, errno=%d\n", __func__, + chunk->base, chunk->size, errno); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_hugetlb_1g final : public spine_mem_pool_manager { + public: + spine_mem_pool_hugetlb_1g() : spine_mem_pool_manager(SPINE_MEM_POOL_1G_REGION_SIZE) {} + + ~spine_mem_pool_hugetlb_1g() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + (void) hint_addr; + + size_t region_size = 0; + if (!align_up(min_size, SPINE_MEM_POOL_1G_REGION_SIZE, ®ion_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round hugetlb_1g size for %zu\n", __func__, min_size); + return false; + } + + const int fd = open(SPINE_MEM_POOL_HUGETLB_1G_DEV, O_RDWR); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_HUGETLB_1G_DEV, errno); + return false; + } + + hugetlb_1g_region region; + region.size = region_size; + region.flags = HUGETLB_1G_FLAG_REQUIRE_PUD; + if (ioctl(fd, HUGETLB_1G_IOC_ALLOC, ®ion) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_ALLOC failed for size %zu, errno=%d\n", __func__, + region_size, errno); + close(fd); + return false; + } + + void * map_addr = mmap(nullptr, region.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for hugetlb_1g size %llu, errno=%d\n", __func__, + static_cast(region.size), errno); + ioctl(fd, HUGETLB_1G_IOC_FREE); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = region.size; + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for hugetlb_1g chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + if (ioctl(chunk->fd, HUGETLB_1G_IOC_FREE) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_FREE failed for chunk %p, errno=%d\n", + __func__, chunk->base, errno); + } + + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_shared_mem final : public spine_mem_pool_manager { + public: + spine_mem_pool_shared_mem() : spine_mem_pool_manager(SPINE_SHARE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_shared_mem() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + if (hint_addr != nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem does not support multiple active chunks\n", __func__); + return false; + } + + if (min_size > default_chunk_size()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem request %zu exceeds chunk size %zu\n", __func__, + min_size, default_chunk_size()); + return false; + } + + const int fd = open(SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, O_RDWR | O_SYNC); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, errno); + return false; + } + + void * map_addr = mmap(nullptr, default_chunk_size(), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for %s size %zu, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, default_chunk_size(), errno); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = default_chunk_size(); + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for shared_mem chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +spine_mem_pool_manager & get_spine_mem_pool_manager() { + static std::once_flag pool_once; + static std::unique_ptr selected_pool; + static spine_mem_pool_backend selected_backend = spine_mem_pool_backend::none; + + spine_mem_pool_backend backend = global_spine_env_info.mem_backend; + if (backend == spine_mem_pool_backend::none) { + backend = spine_mem_pool_backend::transparent_hugepage; + } + + std::call_once(pool_once, [&]() { + selected_backend = backend; + + switch (selected_backend) { + case spine_mem_pool_backend::posix_memalign: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::transparent_hugepage: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::hugetlb_1g: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::none: + selected_backend = spine_mem_pool_backend::transparent_hugepage; + selected_pool = std::make_unique(); + break; + } + }); + + if (backend != selected_backend) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: mem pool backend is process-global and mutually exclusive, requested=%d but " + "selected=%d\n", + __func__, static_cast(backend), static_cast(selected_backend)); + } + + if (selected_pool) { + return *selected_pool; + } + + throw std::bad_alloc(); +} + +spine_mem_pool_manager & get_spine_mem_pool_shared_mem_manager() { + static std::once_flag shared_mem_pool_once; + static std::unique_ptr shared_mem_pool; + + std::call_once(shared_mem_pool_once, [&]() { shared_mem_pool = std::make_unique(); }); + + if (shared_mem_pool) { + return *shared_mem_pool; + } + + throw std::bad_alloc(); +} + +} // namespace + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept { + if (info == nullptr) { + return false; + } + + *info = {}; + + if (spine_tcm_open_handle(NULL) != 0 || !spine_tcm_is_available()) { + return false; + } + + spine_tcm_mem_info_t mem_info; + if (spine_tcm_mem_info(&mem_info) != 0) { + return false; + } + + info->available = true; + info->blk_size = mem_info.blk_size; + info->blk_num = mem_info.blk_num; + info->is_fake_tcm = mem_info.is_fake_tcm != 0; + return true; +} + +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept { + return spine_tcm_mem_get(cpu_id); +} + +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept { + return spine_tcm_mem_try_wait(cpu_id, 1000 * 1000); +} + +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept { + return spine_tcm_mem_release(cpu_id); +} + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating size %zu\n", __func__, size); + return nullptr; + } +} + +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_shared_mem_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating shared memory size %zu\n", __func__, size); + return nullptr; + } +} + +void spine_mem_pool_free(void * base) noexcept { + try { + get_spine_mem_pool_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing allocation %p\n", __func__, base); + } +} + +void spine_mem_pool_shared_mem_free(void * base) noexcept { + try { + get_spine_mem_pool_shared_mem_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing shared allocation %p\n", __func__, base); + } +} + +} // namespace ggml::cpu::riscv64_spacemit + +extern "C" { +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment) { + void * result = ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_alloc(size, alignment); + if (result == nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to allocate shared memory size %zu alignment %zu\n", __func__, + size, alignment); + } + return result; +} + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr) { + ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_free(ptr); +} +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h new file mode 100644 index 00000000000..8740d2c99ef --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +enum class spine_mem_pool_backend : uint8_t { + none, + posix_memalign, + transparent_hugepage, + hugetlb_1g, +}; + +struct spine_mem_pool_tcm_info { + bool available{ false }; + size_t blk_size{ 0 }; + size_t blk_num{ 0 }; + bool is_fake_tcm{ false }; +}; + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept; +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept; +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept; +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept; + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept; +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept; +void spine_mem_pool_free(void * base) noexcept; +void spine_mem_pool_shared_mem_free(void * base) noexcept; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/spine_tcm.h b/ggml/src/ggml-cpu/spacemit/spine_tcm.h new file mode 100644 index 00000000000..f300d7d5c04 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_tcm.h @@ -0,0 +1,409 @@ +#ifndef SPINE_TCM_PUBLIC_H_ +#define SPINE_TCM_PUBLIC_H_ + +/* + * spine_tcm public API + * + * Usage: + * 1. Direct link mode + * Define SPINE_TCM_DIRECT_LINK and link against libspine_tcm.so. + * + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + * + * 2. Header-only loader mode + * Include this header without linking libspine_tcm.so. The loader first + * tries to reuse a process-global spine_tcm instance and falls back to + * dlopen("libspine_tcm.so") when needed. + * + * spine_tcm_open_handle(NULL); // optional pre-bind + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + */ + +#include +#include +#include + +#if !defined(SPINE_TCM_BUILD_SHARED) && !defined(SPINE_TCM_DIRECT_LINK) +# include +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_WIN32) +# if defined(SPINE_TCM_BUILD_SHARED) +# define SPINE_TCM_API __declspec(dllexport) +# else +# define SPINE_TCM_API __declspec(dllimport) +# endif +#else +# define SPINE_TCM_API __attribute__((visibility("default"))) +#endif + +typedef struct spine_tcm_mem_info { + size_t blk_size; + size_t blk_num; + int is_fake_tcm; +} spine_tcm_mem_info_t; + +typedef struct spine_tcm_block_info { + int id; + void * va; + size_t size; + uint64_t phys_addr; + uint64_t cpu_affinity_mask; + int owner_tid; + int is_acquired; +} spine_tcm_block_info_t; + +/* Shared-library runtime ABI exported by libspine_tcm.so. */ +SPINE_TCM_API const char * spine_tcm_runtime_version(void); +SPINE_TCM_API int spine_tcm_runtime_is_available(void); +SPINE_TCM_API int spine_tcm_runtime_layout_info(spine_tcm_mem_info_t * info); +SPINE_TCM_API int spine_tcm_runtime_mem_info(int id, spine_tcm_block_info_t * info); +SPINE_TCM_API void * spine_tcm_runtime_mem_get(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_free(int id); +SPINE_TCM_API void * spine_tcm_runtime_mem_try_wait(int id, size_t timeout_us); +SPINE_TCM_API int spine_tcm_runtime_mem_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_force_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_query(int id); + +#if defined(SPINE_TCM_DIRECT_LINK) +/* Optional no-op in direct-link mode. */ +static inline int spine_tcm_open_handle(const char * so_path) { + (void) so_path; + return 0; +} + +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + return spine_tcm_runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + return spine_tcm_runtime_layout_info(info); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + return spine_tcm_runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + return spine_tcm_runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + return spine_tcm_runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + return spine_tcm_runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + return spine_tcm_runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + return spine_tcm_runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + return spine_tcm_runtime_mem_query(id); +} +#elif !defined(SPINE_TCM_BUILD_SHARED) +typedef struct spine_tcm_handle { + void * module_handle; + int use_global_scope; + int owns_module_handle; + const char * (*runtime_version)(void); + int (*runtime_is_available)(void); + int (*runtime_layout_info)(spine_tcm_mem_info_t * info); + int (*runtime_mem_info)(int id, spine_tcm_block_info_t * info); + void * (*runtime_mem_get)(int id); + int (*runtime_mem_free)(int id); + void * (*runtime_mem_try_wait)(int id, size_t over_time); + int (*runtime_mem_release)(int id); + int (*runtime_mem_force_release)(int id); + int (*runtime_mem_query)(int id); +} spine_tcm_handle_t; + +static inline spine_tcm_handle_t * spine_tcm_default_handle(void) { + static spine_tcm_handle_t handle = { 0 }; + return &handle; +} + +static inline void spine_tcm_handle_reset(spine_tcm_handle_t * handle) { + if (handle != NULL) { + memset(handle, 0, sizeof(*handle)); + } +} + +static inline int spine_tcm_handle_bind(spine_tcm_handle_t * handle) { + void * symbol_scope = handle->use_global_scope ? RTLD_DEFAULT : handle->module_handle; + + handle->runtime_version = (const char * (*) (void) ) dlsym(symbol_scope, "spine_tcm_runtime_version"); + handle->runtime_is_available = (int (*)(void)) dlsym(symbol_scope, "spine_tcm_runtime_is_available"); + handle->runtime_layout_info = + (int (*)(spine_tcm_mem_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_layout_info"); + handle->runtime_mem_info = + (int (*)(int, spine_tcm_block_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_mem_info"); + handle->runtime_mem_get = (void * (*) (int) ) dlsym(symbol_scope, "spine_tcm_runtime_mem_get"); + handle->runtime_mem_free = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_free"); + handle->runtime_mem_try_wait = (void * (*) (int, size_t)) dlsym(symbol_scope, "spine_tcm_runtime_mem_try_wait"); + handle->runtime_mem_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_release"); + handle->runtime_mem_force_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_force_release"); + handle->runtime_mem_query = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_query"); + + return handle->runtime_version != NULL && handle->runtime_is_available != NULL && + handle->runtime_layout_info != NULL && handle->runtime_mem_info != NULL && + handle->runtime_mem_get != NULL && handle->runtime_mem_free != NULL && + handle->runtime_mem_try_wait != NULL && handle->runtime_mem_release != NULL && + handle->runtime_mem_force_release != NULL && handle->runtime_mem_query != NULL ? + 0 : + -1; +} + +/* + * Try to bind against an already-loaded process-global spine_tcm instance. + * The shared library exports spine_tcm_runtime_marker only for this probe. + */ +static inline int spine_tcm_try_bind_global(spine_tcm_handle_t * handle) { + if (dlsym(RTLD_DEFAULT, "spine_tcm_runtime_marker") == NULL) { + return -1; + } + + handle->use_global_scope = 1; + return spine_tcm_handle_bind(handle); +} + +/* + * Optional pre-bind entry point. + * + * Behavior: + * - Reuses an already-loaded global spine_tcm instance when available. + * - Otherwise loads the shared library from so_path or the default soname. + * - Repeated calls are safe and return 0 after the first successful bind. + */ +static inline int spine_tcm_open_handle(const char * so_path) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + const char * library = (so_path != NULL && so_path[0] != '\0') ? so_path : "libspine_tcm.so"; + + if (resolved->module_handle != NULL || resolved->use_global_scope) { + return 0; + } + + if (spine_tcm_try_bind_global(resolved) == 0) { + return 0; + } + + spine_tcm_handle_reset(resolved); + + resolved->module_handle = dlopen(library, RTLD_LAZY | RTLD_GLOBAL); + resolved->owns_module_handle = resolved->module_handle != NULL ? 1 : 0; + + if (resolved->module_handle == NULL) { + spine_tcm_handle_reset(resolved); + return -1; + } + + if (spine_tcm_handle_bind(resolved) != 0) { + if (resolved->owns_module_handle) { + dlclose(resolved->module_handle); + } + spine_tcm_handle_reset(resolved); + return -1; + } + + return 0; +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_is_available == NULL) { + return 0; + } + + return resolved->runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_layout_info == NULL) { + return -1; + } + + return resolved->runtime_layout_info(info); +} + +static inline const char * spine_tcm_version(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_version == NULL) { + return "unknown"; + } + + return resolved->runtime_version(); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_info == NULL) { + return -1; + } + + return resolved->runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_get == NULL) { + return NULL; + } + + return resolved->runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_free == NULL) { + return -1; + } + + return resolved->runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_try_wait == NULL) { + return NULL; + } + + return resolved->runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_release == NULL) { + return -1; + } + + return resolved->runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || + resolved->runtime_mem_force_release == NULL) { + return -1; + } + + return resolved->runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_query == NULL) { + return -1; + } + + return resolved->runtime_mem_query(id); +} +#else +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} +#endif + +#define SPINE_TCM_VERSION (spine_tcm_version()) + +#ifdef __cplusplus +} +#endif + +#endif From 592a8cd15d028f8d9a709e777641a9736a213565 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 13:05:52 +0300 Subject: [PATCH 091/289] logs : reduce (llama/23021) * logs : reduce * args : fix envs * server : fix build * common : print verbosity level at start * server : clean-up logs * server : print prompt processing timings + sampling params * minor : whitespaces --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fab7891c008..780dfe81bb3 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -672,7 +672,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { ![[dev->mtl_device name] containsString:@"M6"] && ![[dev->mtl_device name] containsString:@"A19"] && ![[dev->mtl_device name] containsString:@"A20"]) { - GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); + GGML_LOG_INFO("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); dev->props.has_tensor = false; } From 13133ab299e94a413fed015841a424adec149b1c Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 14 May 2026 09:31:36 -0700 Subject: [PATCH 092/289] ggml-webgpu: makes the flash attn vec path subgroup-aware (llama/23040) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-webgpu: makes the flash attn vec path compile and size its split/reduce work from the device’s reported subgroup range instead of assuming 32 subgroup size. * ggml-webgpu: remove the extra max_wg_size >= max_subgroup_size guard. Remove hardcoded 32 when determine the value of reduce_wg_size and vec_nwg_cap --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 +++++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 12 +++++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 62a523365b9..4c4eda1cbe5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -770,9 +770,14 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && + context.src2->ne[0] % kv_vec_head_align == 0; + // Compile with enough invocations to cover the largest reported subgroup. + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && + kv_vec_head_dims_aligned && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && @@ -808,7 +813,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = 1u; decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.wg_size = context.max_subgroup_size; if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 401c75c1230..78cb02be06d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1832,7 +1832,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { @@ -1953,8 +1953,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector reduce_params; std::vector reduce_entries; if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; + const uint32_t reduce_wg_size = + std::max(reduce_sg_size, (uint32_t) std::min( + (uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -3542,8 +3545,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const uint32_t kv_tile = decisions.kv_tile; - const uint32_t vec_nwg_cap = std::max( - 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { From e62d5893f4153226e761c2bc0ab80b28a63cb055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 14 May 2026 22:58:58 +0200 Subject: [PATCH 093/289] HIP: RDNA3 mma FA, faster AMD transpose, tune AMD (llama/22880) Adds RDNA3 support to the CUDA mma FA kernel. To make the RDNA3 tensor cores work with the FP16 accumulation for VKQ the tiles they need to be 32 logical units long in direction of the attention head; for head sizes 80 and 112 that are not exactly divided by 32 the regular length of 16 with FP32 accumulation is used instead. The longer tiles also enable more efficient transposition for a warp size of 32 which is why it's also used for RDNA4. However, this scrambles the data layout of the accumulators along the attention head dimension. To prevent accidental misuse I added another entry to ggml_cuda_mma::data_layout. I also tuned the kernel parameters for RDNA3, RDNA4, and CDNA1 in general, during which I discovered that the kernel can be made to work for head sizes up to 256 for CDNA. For RDNA3/4 I was not able to get better performance that the tile kernel for head sizes > 128. --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 319 ++++++++++++++++++++------- ggml/src/ggml-cuda/fattn.cu | 57 ++--- ggml/src/ggml-cuda/mma.cuh | 149 +++++++++++-- 3 files changed, 398 insertions(+), 127 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 43e22c5e5ee..a25e912c4d2 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -125,61 +125,107 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true); - // TODO tune specifically for RDNA - return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { - // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async). - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true); - - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true); - // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy - // compile-time static_asserts even though the kernel guard prevents runtime execution. - // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. - return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { @@ -510,7 +556,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; @@ -712,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; + + // The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed. + // However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout. +#ifdef RDNA3 +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_j(l); + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2; + + KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]); + } +#else #pragma unroll for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { const int i = (i0 + T_C_KQ::get_j(l0)) / 2; @@ -721,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } +#endif // RDNA3 } } @@ -827,13 +886,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2( - KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -901,9 +970,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) { static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { @@ -912,15 +980,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if constexpr (T_B_KQ::I == 8) { - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); } else { // Wide version of VKQ_C is column-major. #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) // AMD matrix C is column-major. - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); #else // swap A and B for CUDA. - mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A); #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } @@ -953,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -965,7 +1033,7 @@ template struct mma_tile_sizes { using T_B_VKQ = tile<16, 8, half2>; // column-major using T_C_VKQ = tile<16, 8, half2>; // column-major }; -template<> struct mma_tile_sizes<8> { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile< 8, 8, half2>; // column-major using T_C_KQ = tile<16, 8, float>; // row-major @@ -973,8 +1041,60 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; -#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) -template struct mma_tile_sizes { +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#else +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -983,7 +1103,7 @@ template struct mma_tile_sizes { using T_C_VKQ = tile<16, 8, half2>; // column-major }; #else // Volta -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major @@ -1018,17 +1138,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; - using T_A_KQ = typename mma_tile_sizes::T_A_KQ; - using T_B_KQ = typename mma_tile_sizes::T_B_KQ; - using T_C_KQ = typename mma_tile_sizes::T_C_KQ; - using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; - using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; - using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -1061,6 +1181,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)]; #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta @@ -1269,12 +1391,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -1433,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { if constexpr (cols_per_warp == 8) { + static_assert(std::is_same_v, "bad VKQ type"); const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { @@ -1447,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } else { const int j0 = threadIdx.y*cols_per_warp; + if constexpr (std::is_same_v) { + if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) { #pragma unroll - for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - const int j = j0 + T_C_VKQ::get_i(l); - const int k = k1 + T_C_VKQ::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + } + } + } else { + static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout"); + using T_C_VKQ_us = tile; // us == unscrambled +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]); +#pragma unroll + for (int l = 0; l < T_C_VKQ_us::ne; ++l) { + const int j = j0 + T_C_VKQ_us::get_i(l); + const int k = k1 + T_C_VKQ_us::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C_us.x[l]; + } + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); + half * tile_Q_h = (half *) tile_Q; +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = 2*k1 + T_C_VKQ::get_j(l); + + tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l]; + } } } } @@ -1532,7 +1697,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } template @@ -1559,7 +1724,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1585,14 +1750,14 @@ static __global__ void flash_attn_ext_f16( #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING #if defined(AMD_WMMA_AVAILABLE) - if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) { NO_DEVICE_CODE; return; } #endif // defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) - if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) { + if (ncols1*ncols2 < 16 || DKQ > 256) { NO_DEVICE_CODE; return; } @@ -1715,7 +1880,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) } template diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index e045b04f727..1c7777e8a71 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -19,13 +19,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } if constexpr (ncols2 <= 16) { - if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + if (Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { + if (Q->ne[1] <= 32/ncols2 || (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) || + (GGML_CUDA_CC_IS_AMD(cc) && DKQ > 256)) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -477,12 +478,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + const int ncols2_max = Q->ne[0] == 320 ? 32 : ((Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8); + int gqa_ratio_eff = 1; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { - int gqa_ratio_eff = 1; - const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; - } if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -500,41 +502,22 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } - if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { - if (can_use_vector_kernel) { - if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { - if (Q->ne[1] == 1) { - if (!gqa_opt_applies) { - return BEST_FATTN_KERNEL_VEC; - } - } - } else { - if (Q->ne[1] <= 2) { - return BEST_FATTN_KERNEL_VEC; - } - } + // AMD MFMA needs a certain minimum batch size to outscale the tile kernel for large head sizes. + if ((amd_mfma_available(cc) && Q->ne[0] <= 256) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if ((Q->ne[0] <= 64 && Q->ne[1] * gqa_ratio_eff > 8)) { + return BEST_FATTN_KERNEL_MMA_F16; } - int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; + if ((Q->ne[0] <= 128 && Q->ne[1] * gqa_ratio_eff > 16)) { + return BEST_FATTN_KERNEL_MMA_F16; } - if (Q->ne[1] * gqa_ratio_eff <= 8) { - return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized. + if ((Q->ne[0] <= 256 && Q->ne[1] * gqa_ratio_eff > 64)) { + return BEST_FATTN_KERNEL_MMA_F16; } - return BEST_FATTN_KERNEL_MMA_F16; } - // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { - const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); - // MMA vs tile crossover benchmarked on MI300X @ d32768: - // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) - // hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%) - if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) { - return BEST_FATTN_KERNEL_MMA_F16; - } - // Fall through to tile kernel for small effective batch sizes. + // AMD WMMA is always faster than the tile kernel if the full tile width of 16 can be utilized. + if ((amd_wmma_available(cc) && gqa_opt_applies && Q->ne[0] <= 128) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[1] * gqa_ratio_eff > 8) { + return BEST_FATTN_KERNEL_MMA_F16; } // If there are no tensor cores available, use the generic tile kernel: diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 79bb2934c5f..8d7c69dc3e8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -80,6 +80,7 @@ namespace ggml_cuda_mma { DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. DATA_LAYOUT_J_MAJOR_MIRRORED = 30, + DATA_LAYOUT_I_MAJOR_SCRAMBLED = 40, // Scrambled matrix C for faster transposition (RDNA4/CDNA), convert to float to unscramble. }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR @@ -312,13 +313,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -327,7 +334,15 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { +#ifdef RDNA3 + return l*2 + (threadIdx.x / 16); +#else + return (threadIdx.x / 16) * ne + l; +#endif // RDNA3 + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -338,13 +353,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -353,7 +374,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -516,12 +541,15 @@ namespace ggml_cuda_mma { if (I == 16 && J == 16) return true; if (I == 16 && J == 8) return true; if (I == 16 && J == 4) return true; + if (I == 32 && J == 8) return true; return false; } - static __device__ __forceinline__ int get_i(const int /*l*/) { - if constexpr (supported()) { + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16) { return threadIdx.x % 16; + } else if constexpr (I == 32) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -529,8 +557,10 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (supported()) { + if constexpr (I == 16) { return l; + } else if constexpr (I == 32) { + return l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -644,6 +674,40 @@ namespace ggml_cuda_mma { } }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_SCRAMBLED; + + static constexpr int ne = I * J / ggml_cuda_get_physical_warp_size(); + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + }; + + static __device__ __forceinline__ tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> unscramble(const tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & t) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> ret; +#pragma unroll + for (int l0 = 0; l0 < t.ne/2; ++l0) { + ret.x[2*l0 + 0] = __lows2half2(t.x[l0], t.x[l0 + t.ne/2]); + ret.x[2*l0 + 1] = __highs2half2(t.x[l0], t.x[l0 + t.ne/2]); + } + return ret; +#else + NO_DEVICE_CODE; + GGML_UNUSED(t); +#endif // defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + } + #if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { @@ -660,6 +724,21 @@ namespace ggml_cuda_mma { ret.x[0] = ggml_cuda_movmatrix(t.x[0]); ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + return ret; + } +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + static __device__ __forceinline__ tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> get_half2( + const tile<16, 16, float, DATA_LAYOUT_I_MAJOR> & tile_float) { + tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> ret; +#pragma unroll + for (int l = 0; l < tile_float.ne; ++l) { + float tmp[2]; + int i = threadIdx.x / 16; + tmp[i] = tile_float.x[l]; + i ^= 1; + tmp[i] = __shfl_xor_sync(0xFFFFFFFF, tile_float.x[l], 16, WARP_SIZE); + ret.x[l] = make_half2(tmp[0], tmp[1]); + } return ret; } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) @@ -802,21 +881,35 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } - template + template static __device__ __forceinline__ void load_ldmatrix_trans( - tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { + tile & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE + static_assert(I == 16, "bad tile width"); + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - half * xh = (half *) t.x; + static_assert(dl == DATA_LAYOUT_I_MAJOR || dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + if constexpr (I == 32) { #pragma unroll - for (int l = 0; l < t.ne; ++l) { - xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; - xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + for (int l0 = 0; l0 < t.ne/2; ++l0) { + const half2 tmp0 = xs0[(2*t.get_j(l0) + 0)*stride + t.get_i(l0)/2]; + const half2 tmp1 = xs0[(2*t.get_j(l0) + 1)*stride + t.get_i(l0)/2]; + + t.x[l0] = __lows2half2(tmp0, tmp1); + t.x[l0 + t.ne/2] = __highs2half2(tmp0, tmp1); + } + } else { + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } } #else GGML_UNUSED_VARS(t, xs0, stride); @@ -972,6 +1065,20 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR> & B) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 8, half2> * D16 = (tile<16, 8, half2> *) &D; + const tile<16, 8, half2> * A16 = (const tile<16, 8, half2> *) &A; + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) && defined(RDNA4) + } + template static __device__ __forceinline__ void mma( tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { @@ -1296,6 +1403,22 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { +#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + halfx16_t * xD = (halfx16_t *) D.x; + const halfx16_t * xA = (const halfx16_t *) A.x; + const halfx16_t * xB = (const halfx16_t *) B.x; + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[0], xB[0], xD[0], /*opsel =*/ 0); + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[1], xB[0], xD[0], /*opsel =*/ 1); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { From 18a61f44b63f34bdc05f7c88724b174b706ab149 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Thu, 14 May 2026 16:55:54 -0700 Subject: [PATCH 094/289] ggml-hexagon: cpy: add contiguous fast-path in reshape copy (llama/23076) --- ggml/src/ggml-hexagon/htp/cpy-ops.c | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index e5b9d350fd7..5c040a32224 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -88,6 +88,29 @@ static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + // Fast path: when both src0 and dst are contiguous in memory + // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. + const bool src0_contig = (nb00 == ct->src0_type_size) && + (nb01 == ne00 * nb00) && + (nb02 == ne01 * nb01) && + (nb03 == ne02 * nb02); + const bool dst_contig = (nb0 == ct->dst_type_size) && + (nb1 == ne0 * nb0) && + (nb2 == ne1 * nb1) && + (nb3 == ne2 * nb2); + + if (src0_contig && dst_contig) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); + } + } + return; + } + // dst counters int64_t k10 = 0; int64_t i11 = 0; From 23f956de336846ea28d7c2bc4c6d370216527203 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 20:06:23 +0800 Subject: [PATCH 095/289] llama + spec: MTP Support (llama/22673) * spec: support MTP * fix batch size * rename files * cont : simplify (llama/7) * MTP: clean-up (llama/9) * MTP: clean-up * review: use llama_context_type instead of llama_graph_type * review: remove llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion * mtp -> draft-mtp * remove unused llama_arch * add need_embd in speculative * llama: allow partial seq_rm for GDN models for speculative decoding Currently speculative checkpoint needs to restart from a checkpoint after some draft tokens are not accepted, this leads to some wastage in running the target again. This PR adds the ability to rollback upto `draft_max` by storing the GDN intermediates. * fix pending state * vulkan: add GDN partial rollback * meta: extend check to axis 1 * metal: add GDN partial rollback Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend. - Add K (snapshot slot count) as a function constant - Read input state from slot 0 of the 3D state tensor - Write intermediate states to different slots during token loop - For K=1, maintain backward-compatible single-slot behavior Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc Assisted-by: llama.cpp:local pi * delta_net_base: use ggml_pad instead of new_tensor * review: add need_rs_seq * review: rename part_bounded to n_rs * review: deslop comments * review: rename, add asserts * server : adjust checkpoint logic (llama/11) * server : adjust checkpoint logic * cont : rm asserts * server-context: fix early exit * spec : fix compatibility with n-gram and add TODOs (llama/13) * metal : cleanup * llama : fix faulty bitwise check in recurrent memory * server : disable RS-based MTP in combination with other spec types * spec : add TODOs * cont : fix comment * cont : update comment * common : fix logic for ngram + mtp compat * llama-memory: enable checkpointing with partial rollback * cont: add test-case for loading into a dirty ctx * llama-memory-recurrent: clear rs_idx in clear * download: fix mtp path * llama-arch: fix enorm op * docs: update docs * conversion: fix type annotations --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml.h | 5 ++ ggml/src/ggml-backend-meta.cpp | 5 +- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- ggml/src/ggml-cpu/ops.cpp | 43 +++++++-- ggml/src/ggml-cuda/gated_delta_net.cu | 88 +++++++++++++------ ggml/src/ggml-metal/ggml-metal-device.cpp | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 46 ++++++++-- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +- .../vulkan-shaders/gated_delta_net.comp | 29 +++++- ggml/src/ggml.c | 12 +-- 10 files changed, 188 insertions(+), 57 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..41566d41aef 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2541,6 +2541,11 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. + // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a048b..df0f405ed9f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -2140,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b05edf6b75..cd5c61a8187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..7485ba4fc86 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); // attn output pointer for first token of this (head, seq) @@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } } } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec7317..b4c9845e7a7 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,6 @@ #include "gated_delta_net.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q, int64_t sb3, const uint3 neqk1_magic, const uint3 rq3_magic, - float scale) { + float scale, + int K) { const uint32_t h_idx = blockIdx.x; const uint32_t sequence = blockIdx.y; // each warp owns one column, using warp-level primitives to reduce across rows @@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset + col * S_v; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q, } attn_data += S_v * H; + + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } } - // Write state back to global memory (transposed layout) + if constexpr (!keep_rs_t) { #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -155,7 +177,7 @@ static void launch_gated_delta_net( int64_t sv1, int64_t sv2, int64_t sv3, int64_t sb1, int64_t sb2, int64_t sb3, int64_t neqk1, int64_t rq3, - float scale, cudaStream_t stream) { + float scale, int K, cudaStream_t stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const int num_warps = 4; @@ -169,29 +191,29 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( + gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA><<>>( + gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( + gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( + gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } default: @@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84c1..e288a27f992 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d45de8cce2..f6ffb2b3a1c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl( uint3 ntg[[threads_per_threadgroup]]) { #define S_v FC_gated_delta_net_ne20 #define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 1.0f / sqrt((float)S_v); + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - dst_state[is] = ls[j]; + if (K == 1u) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v #undef G +#undef K } typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8c4cf9ef1db..d29a4bab2e2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c1d..33c3202dbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa106..476c3079795 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. + GGML_ASSERT(state->ne[0] == S_v * S_v * H); + GGML_ASSERT(state->ne[2] == n_seqs); + GGML_ASSERT(state->ne[3] == 1); + const int64_t K = state->ne[1]; + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET; From 587dca0eda5168b4dcf77585b15ab33d179b3b27 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 May 2026 15:59:09 +0300 Subject: [PATCH 096/289] ggml : bump version to 0.12.0 (ggml/1494) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index bdeca34bf9f..4aac5094d1c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 11) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 12) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 3583e35e0d0d07993b96a07788e655acadc249a5 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Thu, 21 May 2026 17:28:08 +0530 Subject: [PATCH 097/289] ggml-alloc: fix out-of-bounds read in ggml_dyn_tallocr_remove_block (ggml/1492) --- ggml/src/ggml-alloc.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index a4b01ccf8a1..3bda9abbe03 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -150,7 +150,7 @@ static void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t o static void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) { // shift all elements after idx by 1 to the left, overwriting the element at idx - for (int i = idx; i < chunk->n_free_blocks; i++) { + for (int i = idx; i < chunk->n_free_blocks - 1; i++) { chunk->free_blocks[i] = chunk->free_blocks[i+1]; } chunk->n_free_blocks--; From e78e69301721c8f804397f4cb356bc1a217b39f2 Mon Sep 17 00:00:00 2001 From: Ori Pekelman Date: Thu, 21 May 2026 12:00:16 +0000 Subject: [PATCH 098/289] ggml.h: correct ggml_silu_back arg docstring (a=dy, b=x) (ggml/1500) --- ggml/include/ggml.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 41566d41aef..f6725265504 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1189,8 +1189,8 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // a - x - // b - dy + // a - dy + // b - x GGML_API struct ggml_tensor * ggml_silu_back( struct ggml_context * ctx, struct ggml_tensor * a, From ef5ddecff9c9ece7946048aa4b193825bb916cb7 Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Sun, 17 May 2026 01:57:35 +0800 Subject: [PATCH 099/289] vulkan: removed duplicate #include in headers (llama/23144) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d29a4bab2e2..a296d0ab446 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -49,7 +49,6 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include -#include #include #include #include From c7dd64c6062adff284b04679612ffb3261eeffba Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 May 2026 03:25:50 -0500 Subject: [PATCH 100/289] vulkan: fuse SSM_CONV + BIAS + SILU (llama/22653) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 126 ++++++++++++++++-- .../ggml-vulkan/vulkan-shaders/ssm_conv.comp | 12 +- 2 files changed, 129 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a296d0ab446..d76d4819026 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -854,6 +854,8 @@ struct vk_device_struct { vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; + vk_pipeline pipeline_ssm_conv_silu_f32; + vk_pipeline pipeline_ssm_conv_bias_silu_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; std::map pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -4900,7 +4902,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -9936,7 +9940,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SSM_CONV: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_ssm_conv_f32; + switch (ctx->num_additional_fused_ops) { + case 0: return ctx->device->pipeline_ssm_conv_f32; + case 1: return ctx->device->pipeline_ssm_conv_silu_f32; + case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32; + default: return nullptr; + } } return nullptr; case GGML_OP_OPT_STEP_ADAMW: @@ -10877,11 +10886,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, pc, elements); } -static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { + ggml_tensor * conv = cgraph->nodes[node_idx]; + const ggml_tensor * src0 = conv->src[0]; + const ggml_tensor * src1 = conv->src[1]; + + // Pick the destination tensor (last node in the fused chain) and the optional bias. + // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu. + ggml_tensor * dst = conv; + const ggml_tensor * bias = nullptr; + + if (ctx->num_additional_fused_ops == 1) { + dst = cgraph->nodes[node_idx + 1]; // silu + } else if (ctx->num_additional_fused_ops == 2) { + ggml_tensor * add = cgraph->nodes[node_idx + 1]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + dst = cgraph->nodes[node_idx + 2]; // silu + } - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { + // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused. + const ggml_tensor * src2 = bias ? bias : src0; + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, { (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], (uint32_t)src1->nb[1], (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], @@ -13556,7 +13582,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SSM_CONV: - ggml_vk_ssm_conv(ctx, compute_ctx, node); + ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx); break; @@ -14453,6 +14479,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return true; } +// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2. +static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx, int num_extra) { + const ggml_tensor * conv = cgraph->nodes[node_idx]; + if (conv->op != GGML_OP_SSM_CONV) { + return false; + } + + const ggml_tensor * silu = nullptr; + const ggml_tensor * bias = nullptr; + + if (num_extra == 1) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) { + return false; + } + silu = cgraph->nodes[node_idx + 1]; + } else if (num_extra == 2) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) { + return false; + } + const ggml_tensor * add = cgraph->nodes[node_idx + 1]; + silu = cgraph->nodes[node_idx + 2]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + // bias must be channel-wise (one element per channel of the conv output) + if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) { + return false; + } + if (add->type != GGML_TYPE_F32) { + return false; + } + // The shader doesn't apply per-tensor offsets, so reject misaligned bias. + if (get_misalign_bytes(ctx, bias) != 0) { + return false; + } + } else { + return false; + } + + if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) { + return false; + } + if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + // The shader writes to the fused dst using its own strides, but the push constants don't + // carry a per-tensor offset, so the binding must be naturally aligned. + if (get_misalign_bytes(ctx, silu) != 0) { + return false; + } + return true; +} + static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, topk_moe_mode mode) { @@ -14869,6 +14951,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // they are overwritten, and one workgroup per row. So close enough. op_srcs_fused_elementwise[0] = true; op_srcs_fused_elementwise[1] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) { + ctx->num_additional_fused_ops = 2; + fusion_string = "SSM_CONV_BIAS_SILU"; + // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs. + // The downstream add and silu are elementwise on the conv output. + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) { + ctx->num_additional_fused_ops = 1; + fusion_string = "SSM_CONV_SILU"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -15200,7 +15295,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) { ok = false; break; } @@ -15283,6 +15380,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } } } + // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward + if (j > 0 && + graph->nodes[j]->op == GGML_OP_ADD && + graph->nodes[j-1]->op == GGML_OP_SSM_CONV) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_UNARY && + graph->nodes[k]->src[0] == graph->nodes[j]) { + current_set.push_back(k); + used[k] = true; + break; + } + } + } } } // Second pass grabs view nodes. diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index 6802b1fc955..4cd9b8da359 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -6,12 +6,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 1) const uint TOKENS_PER_WG = 16; +layout(constant_id = 2) const bool APPLY_BIAS = false; +layout(constant_id = 3) const bool APPLY_SILU = false; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; -layout(binding = 2) buffer Dst { float dst[]; }; +layout(binding = 2) readonly buffer Bias { float bias[]; }; +layout(binding = 3) buffer Dst { float dst[]; }; layout(push_constant) uniform PushConstants { uint nb01; uint nb02; @@ -45,6 +48,13 @@ void main() { } } + if (APPLY_BIAS) { + sum += bias[i1]; + } + if (APPLY_SILU) { + sum = sum / (1.0f + exp(-sum)); + } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; } From e417ce7aebd9305ce99d8ba89230b9825dc409c2 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 May 2026 04:30:16 -0500 Subject: [PATCH 101/289] vulkan: Support unaligned tensors for ROPE (llama/22637) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 17 +++++++++++++++++ .../ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 7 +++++-- .../ggml-vulkan/vulkan-shaders/rope_params.glsl | 3 +++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d76d4819026..14eab8ea4de 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1354,6 +1354,8 @@ struct vk_op_rope_push_constants { uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t a_offset; + uint32_t d_offset; }; static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); @@ -10126,6 +10128,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + template static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; @@ -11270,6 +11281,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * (uint32_t)src0->ne[2], nb01, nb02, nb03, nb11, nb12, nb13, + 0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets }; return rope; @@ -11365,6 +11377,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(buf[i] != nullptr); } + // a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned. + // Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset. + pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type); + offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1); + std::array elements; elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index 2e53459909d..03358793140 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -9,7 +9,7 @@ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -48,6 +48,7 @@ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); @@ -84,6 +85,7 @@ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -121,6 +123,7 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -176,7 +179,7 @@ void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rop return; } - const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 2e2a7e14c66..3602485b943 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -26,6 +26,9 @@ struct rope_params { uint nb11; uint nb12; uint nb13; + + uint a_offset; + uint d_offset; }; #endif // !defined(GGML_ROPE_PARAMS) From 50482cbd229dda298f753e0ebcb6f73a6c219920 Mon Sep 17 00:00:00 2001 From: Pascal Date: Sun, 17 May 2026 11:31:20 +0200 Subject: [PATCH 102/289] vulkan: add cpy bf16 -> f32 pipelines (llama/22677) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 14 ++++++++++++-- .../ggml-vulkan/vulkan-shaders/contig_copy.comp | 8 ++++++-- ggml/src/ggml-vulkan/vulkan-shaders/copy.comp | 4 +++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 ++ 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 14eab8ea4de..d3fb19048d9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -759,8 +759,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; @@ -4572,6 +4572,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4580,6 +4581,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -7544,6 +7546,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_bf16_f32; + } else { + return ctx->device->pipeline_cpy_bf16_f32; + } + } if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_i32; @@ -15974,6 +15983,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index ca1a3ac25bd..b3b182fb084 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -19,7 +19,9 @@ void main() { if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) @@ -35,7 +37,9 @@ void main() { continue; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 9f8bfd3c182..d55e13253a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,9 @@ void main() { return; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d99b2b5d802..e3a9d61a558 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -731,6 +731,7 @@ void process_shaders() { string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); @@ -738,6 +739,7 @@ void process_shaders() { string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); From 9e96e0eaf1847b874dd9fd1734b0c9a469d6548c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ekstr=C3=B6m?= Date: Sun, 17 May 2026 14:12:11 +0300 Subject: [PATCH 103/289] ggml-vulkan/CMakeLists: add a check for SPIRV-Headers (llama/22009) * ci/run: set explicit SPIR-V Headers search path for macOS vulkan CI For whatever reason, the files are under additional sub-path `vulkan/` under the cmake directory, which does not match either current LunarG macOS Vulkan SDK structure (`lib/cmake/SPIRV-Headers`), nor what gets installed when you run the cmake build+install for SPIRV-Headers itself on at least Linux (`share/cmake/SPIRV-Headers`). This allows for SPIRV-Headers to be found, as currently the CI runner's setup does not seem to include the relevant path in list of search locations. * ggml-vulkan/CMakeLists: add a check for SPIRV-Headers This is installed by the project if it is built and installed. Receiving an error during the configuration step is generally preferred to receiving an error in the middle of a build. --- ggml/src/ggml-vulkan/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 715a263a6d0..6dbcea065b3 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,8 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +find_package(SPIRV-Headers REQUIRED) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files add_definitions(/MP) From 53736a3f0e91e132901a55e562252a3519b669dc Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Sun, 17 May 2026 18:00:10 +0200 Subject: [PATCH 104/289] CUDA: Continue directly including cuda/iterator (llama/23102) Cont of #22936, forgot to update one site --- ggml/src/ggml-cuda/top-k.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 59ce36fb1c9..db1d39e2dc7 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -5,6 +5,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) # define CUB_TOP_K_AVAILABLE +# include using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB From 4fb3ccabd38cf5c1103c545ed1d06cc552a44915 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Sun, 17 May 2026 15:05:11 -0600 Subject: [PATCH 105/289] feat: Support d_conv=15 for ssm-conv.cu (llama/23017) Branch: ModalityConditionalAdapters AI-usage: none Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/ssm-conv.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4841389fbc8..4c4daf85dc6 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -140,11 +140,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa }; switch (nc) { - case 3: launch_kernel(std::integral_constant{}); break; - case 4: launch_kernel(std::integral_constant{}); break; - case 5: launch_kernel(std::integral_constant{}); break; - case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + case 15: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } From 619262ad247dcaa319e06bd505f691ae1920b019 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Sun, 17 May 2026 22:11:51 -0700 Subject: [PATCH 106/289] sycl: route small f32 matmuls to oneMKL, bypass oneDNN (llama/22150) Signed-off-by: Chun Tao Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/ggml-sycl.cpp | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f5d10b56de0..ebe7c5b351c 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2385,21 +2385,25 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + { + const int64_t gemm_flops = (int64_t)row_diff * src1_ncols * ne10; + const bool use_mkl_direct = gemm_flops < 256 * 256 * 256; #if GGML_SYCL_DNNL - if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, - DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); - } - else + if (!g_ggml_sycl_disable_dnn && !use_mkl_direct) { + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif - { - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } } GGML_UNUSED(dst); From c65b082c947effa60f679f78065a82929e7e52b1 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Sun, 17 May 2026 22:12:21 -0700 Subject: [PATCH 107/289] sycl: scalar SWAR byte-subtract in Q6_K MMVQ dot product (llama/22156) Signed-off-by: Chun Tao Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/vecdotq.hpp | 99 ++++++++++++++++------------------ 1 file changed, 46 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index d7770047424..16b2d65d271 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -85,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned( (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } +static __dpct_inline__ int byte_sub_4(const int a, const int b) { + const uint32_t ua = static_cast(a); + const uint32_t ub = static_cast(b); + return static_cast(((ua | 0x80808080u) - ub) ^ 0x80808080u); +} + +static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar( + const int vl, const int vh, const int u0, const int u1, const int8_t sc0, + const int8_t sc1, const float d, const float d80, const float d81) { + static_assert(QR6_K == 2, "q6_K MMVQ scalar fast path assumes QR6_K == 2"); + + const int vil0 = (vl >> 0) & 0x0F0F0F0F; + const int vih0 = ((vh >> 0) << 4) & 0x30303030; + const int vi0 = byte_sub_4(vil0 | vih0, 0x20202020); + + const int vil1 = (vl >> 4) & 0x0F0F0F0F; + const int vih1 = ((vh >> 4) << 4) & 0x30303030; + const int vi1 = byte_sub_4(vil1 | vih1, 0x20202020); + + const float sumf = + d80 * (dpct::dp4a(vi0, u0, 0) * sc0) + + d81 * (dpct::dp4a(vi1, u1, 0) * sc1); + + return d * sumf; +} + static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4, const uint8_t *values, int &val1, int &val2) { @@ -279,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, const int *__restrict__ u, const int8_t *__restrict__ scales, const float &d, const float *__restrict__ d8) { - - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4*i]; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4*i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary( - (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d*sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called @@ -542,23 +552,8 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float d, const float * __restrict__ d8) { - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4 * i]; - - const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary((vil | vih), 0x20202020, - dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d * sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, @@ -579,16 +574,15 @@ template <> struct reorder_vec_dot_q_sycl { const int8_t * scs = scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; + const int u0 = get_int_from_int8_aligned( + q8_1_quant_ptr + bq8_offset * QK8_1, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + q8_1_quant_ptr + (bq8_offset + 2) * QK8_1, iqs % QI8_1); + const float d80 = (*(q8_1_ds + bq8_offset + 0))[0]; + const float d81 = (*(q8_1_ds + bq8_offset + 2))[0]; -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); - const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); - d8[i] = ds_values[0]; - } - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scs[0], scs[4], *d, d80, d81); } }; #define VDR_Q4_0_Q8_1_MMVQ 2 @@ -1167,16 +1161,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq, const int8_t * scales = bq6_K->scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + 2 * i].ds[0]; - } + const int u0 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 0].qs, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 2].qs, iqs % QI8_1); + const float d80 = bq8_1[bq8_offset + 0].ds[0]; + const float d81 = bq8_1[bq8_offset + 2].ds[0]; - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scales[0], scales[4], bq6_K->d, d80, d81); } From 0a11c9fe835b21d484e9161524d2e9dc2288fb12 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Mon, 18 May 2026 13:39:36 -0700 Subject: [PATCH 108/289] ggml-hexagon: add PAD op HVX kernel (llama/23078) * ggml-hexagon: add PAD op HVX kernel Implements GGML_OP_PAD on the Hexagon HTP backend using HVX vectorized kernels. Supports zero-padding and circular padding across all 4 tensor dimensions. * hex-ggml: remove duplicate op cases (merge conflict) * hex-pad: fix editorconfig checks and macro alignment --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 18 + ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/pad-ops.c | 545 +++++++++++++++++++++++ 6 files changed, 569 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/pad-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d1c9da8329..c24a2305e4c 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2744,6 +2744,18 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_pad(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; @@ -2857,6 +2869,8 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_PAD: return HTP_OP_PAD; + case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3416,6 +3430,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_solve_tri(sess, op); break; + case GGML_OP_PAD: + supp = ggml_hexagon_supported_pad(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index bcadac11f95..36f923243cd 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -38,6 +38,7 @@ add_library(${HTP_LIB} SHARED diag-ops.c solve-tri-ops.c gated-delta-net-ops.c + pad-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 92f02eac6e3..e500ce46212 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -107,5 +107,6 @@ int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); +int op_pad(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 98db864dd42..985ded6f299 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -86,6 +86,7 @@ enum htp_op_code { HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, HTP_OP_GATED_DELTA_NET, + HTP_OP_PAD, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 883a31d6163..85569f07289 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -595,6 +595,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_SOLVE_TRI: return op_solve_tri(octx); + case HTP_OP_PAD: + return op_pad(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c new file mode 100644 index 00000000000..3abc3c2ead1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -0,0 +1,545 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +/* Circular wrap: maps any integer x into [0, n) */ +static inline uint32_t wrap_around(int32_t x, uint32_t n) { + return (uint32_t)(((x % (int32_t)n) + (int32_t)n) % (int32_t)n); +} + +/* Decompose a flat dst row index into (i1, i2, i3) */ +static inline void pad_decompose_row(uint32_t ir, uint32_t ne1, uint32_t ne2, + uint32_t *i1, uint32_t *i2, uint32_t *i3) { + *i1 = ir % ne1; + *i2 = (ir / ne1) % ne2; + *i3 = ir / (ne1 * ne2); +} + +/* Return non-zero if row (i1,i2,i3) falls in the non-padded interior */ +static inline int pad_is_interior(uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t rp1, uint32_t ne1, + int32_t lp2, int32_t rp2, uint32_t ne2, + int32_t lp3, int32_t rp3, uint32_t ne3) { + return ((int32_t)i1 >= lp1 && (int32_t)i1 < (int32_t)ne1 - rp1) && + ((int32_t)i2 >= lp2 && (int32_t)i2 < (int32_t)ne2 - rp2) && + ((int32_t)i3 >= lp3 && (int32_t)i3 < (int32_t)ne3 - rp3); +} + +/* Compute the DDR src row pointer for a zero-pad interior row */ +static inline const uint8_t * pad_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + (i1 - (uint32_t)lp1) * src->nb[1] + + (i2 - (uint32_t)lp2) * src->nb[2] + + (i3 - (uint32_t)lp3) * src->nb[3]; +} + +/* Compute the DDR src row pointer for a circular row (wrap-around indexing) */ +static inline const uint8_t * pad_circ_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + wrap_around((int32_t)i1 - lp1, src->ne[1]) * src->nb[1] + + wrap_around((int32_t)i2 - lp2, src->ne[2]) * src->nb[2] + + wrap_around((int32_t)i3 - lp3, src->ne[3]) * src->nb[3]; +} + +struct htp_pad_context { + struct htp_ops_context * octx; + + int32_t lp0, rp0; + int32_t lp1, rp1; + int32_t lp2, rp2; + int32_t lp3, rp3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; + + size_t type_size; + + // Row sizes for DMA kernel (populated when VTCM is available) + size_t src_row_size; + size_t src_row_size_aligned; + size_t dst_row_size; + size_t dst_row_size_aligned; +}; + +#define htp_pad_preamble \ + const struct htp_tensor * src = octx->src[0]; \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne00 = src->ne[0]; \ + const uint32_t nb00 = src->nb[0]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const int32_t lp0 = pctx->lp0, rp0 = pctx->rp0; \ + const int32_t lp1 = pctx->lp1, rp1 = pctx->rp1; \ + const int32_t lp2 = pctx->lp2, rp2 = pctx->rp2; \ + const int32_t lp3 = pctx->lp3, rp3 = pctx->rp3; \ + \ + const size_t type_size = pctx->type_size; \ + \ + const uint32_t row_start = pctx->nrows_per_thread * ith; \ + const uint32_t row_end = MIN(row_start + pctx->nrows_per_thread, pctx->total_dst_rows); + + +#define htp_pad_dma_preamble \ + const size_t src_row_size = pctx->src_row_size; \ + const size_t src_row_size_aligned = pctx->src_row_size_aligned; \ + const size_t dst_row_size = pctx->dst_row_size; \ + const size_t dst_row_size_aligned = pctx->dst_row_size_aligned; \ + \ + uint8_t * src_spad_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; \ + uint8_t * dst_spad_base = octx->dst_spad.data + ith * octx->dst_spad.size_per_thread; \ + \ + dma_queue * dma = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX vectorized PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_u(dst_ptr, 0.0f, ne0); + } else { + const uint8_t * src_ptr = pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (lp0 > 0) { + hvx_splat_f32_u(dst_ptr, 0.0f, (uint32_t)lp0); + } + + uint8_t * dst_row_start = dst_ptr + (size_t)lp0 * type_size; + if (nb00 == type_size) { + hvx_copy_f32_uu(dst_row_start, src_ptr, ne00); + } else { + for (uint32_t i = 0; i < ne00; i++) { + memcpy(dst_row_start + i * type_size, + src_ptr + (size_t)i * nb00, + type_size); + } + } + + if (rp0 > 0) { + hvx_splat_f32_u(dst_ptr + ((size_t)lp0 + ne00) * type_size, 0.0f, (uint32_t)rp0); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline before the main loop begins. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + const uint8_t * src_ptr = interior + ? pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3) : NULL; + + // Interior row: real DMA (1 row) from DDR to VTCM. + // Border row: null DMA (nrows=0) + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + src_ptr ? src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, src_ptr ? 1 : 0); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, compute in VTCM with aligned HVX ops, + // push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + } else { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + + uint8_t * dst_interior = dst_spad_cur + (size_t)lp0 * type_size; + + if ((uintptr_t)dst_interior % VLEN == 0) { + hvx_copy_f32_aa(dst_interior, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_interior, src_spad_cur, ne00); + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t ni1, ni2, ni3; + pad_decompose_row(next_row, ne1, ne2, &ni1, &ni2, &ni3); + const int next_interior = pad_is_interior(ni1, ni2, ni3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + const uint8_t * next_src_ptr = next_interior + ? pad_src_row_ptr(src, ni1, ni2, ni3, lp1, lp2, lp3) : NULL; + + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + next_src_ptr ? next_src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, next_src_ptr ? 1 : 0); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX circular PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + const uint8_t * src_row = pad_circ_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (nb00 == type_size) { + + if (lp0 > 0) { + if ((uint32_t)lp0 < 32) { + memcpy(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (uint32_t)lp0); + } + } + hvx_copy_f32_uu(dst_ptr + (size_t)lp0 * type_size, src_row, ne00); + if (rp0 > 0) { + if ((uint32_t)rp0 < 32) { + memcpy(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (size_t)rp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (uint32_t)rp0); + } + } + } else { + for (uint32_t i = 0; i < (uint32_t)lp0; i++) { + *(float *)(dst_ptr + i * type_size) = + *(const float *)(src_row + (size_t)(ne00 - (uint32_t)lp0 + i) * nb00); + } + for (uint32_t i = 0; i < ne00; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + for (uint32_t i = 0; i < (uint32_t)rp0; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + ne00 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA circular PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline. Every row is a real src DMA (no null DMAs). + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t pi1, pi2, pi3; + pad_decompose_row(ir, ne1, ne2, &pi1, &pi2, &pi3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, pad_circ_src_row_ptr(src, pi1, pi2, pi3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, assemble circular row in VTCM with + // aligned HVX ops, push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + + if (lp0 > 0) { + uint8_t * dst_left = dst_spad_cur; + const uint8_t * src_left = src_spad_cur + (size_t)(ne00 - (uint32_t)lp0) * type_size; + if ((uint32_t)lp0 < 32) { + memcpy(dst_left, src_left, (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_left, src_left, (uint32_t)lp0); + } + } + + { + uint8_t * dst_mid = dst_spad_cur + (size_t)lp0 * type_size; + if ((uintptr_t)dst_mid % VLEN == 0) { + hvx_copy_f32_aa(dst_mid, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_mid, src_spad_cur, ne00); + } + } + + if (rp0 > 0) { + uint8_t * dst_right = dst_spad_cur + ((size_t)lp0 + ne00) * type_size; + if ((uint32_t)rp0 < 32) { + memcpy(dst_right, src_spad_cur, (size_t)rp0 * type_size); + } else { + if ((uintptr_t)dst_right % VLEN == 0) { + hvx_copy_f32_aa(dst_right, src_spad_cur, (uint32_t)rp0); + } else { + hvx_copy_f32_ua(dst_right, src_spad_cur, (uint32_t)rp0); + } + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t nri1, nri2, nri3; + pad_decompose_row(next_row, ne1, ne2, &nri1, &nri2, &nri3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + pad_circ_src_row_ptr(src, nri1, nri2, nri3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_pad(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Only F32 supported + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + default: + FARF(ERROR, "pad-hvx: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int32_t lp0 = octx->op_params[0]; + const int32_t rp0 = octx->op_params[1]; + const int32_t lp1 = octx->op_params[2]; + const int32_t rp1 = octx->op_params[3]; + const int32_t lp2 = octx->op_params[4]; + const int32_t rp2 = octx->op_params[5]; + const int32_t lp3 = octx->op_params[6]; + const int32_t rp3 = octx->op_params[7]; + const int32_t circular = octx->op_params[8]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne00 = src0->ne[0]; + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows > 0 ? total_dst_rows : 1); + + const size_t src_row_size = (size_t)ne00 * type_size; + const size_t dst_row_size = (size_t)ne0 * type_size; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // Total VTCM needed: 2 buffers (ping+pong) for src and dst, per thread + const size_t vtcm_needed = (size_t)n_threads * 2 * (src_row_size_aligned + dst_row_size_aligned); + + const int use_dma = (src0->nb[0] == (uint32_t)type_size) && + (ne00 >= 512) && + (octx->ctx->vtcm_base != NULL) && + (octx->ctx->vtcm_size >= vtcm_needed); + + if (use_dma) { + octx->src0_spad.size_per_thread = 2 * src_row_size_aligned; + octx->dst_spad.size_per_thread = 2 * dst_row_size_aligned; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } + + struct htp_pad_context pctx = { + .octx = octx, + .lp0 = lp0, .rp0 = rp0, + .lp1 = lp1, .rp1 = rp1, + .lp2 = lp2, .rp2 = rp2, + .lp3 = lp3, .rp3 = rp3, + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + .src_row_size = src_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size = dst_row_size, + .dst_row_size_aligned = dst_row_size_aligned, + }; + + FARF(HIGH, "pad-hvx%s%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) pads=(%d,%d,%d,%d,%d,%d,%d,%d)\n", + circular ? "-circ" : "", + use_dma ? "-dma" : "", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + + if (circular && use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular_dma, &pctx, n_threads); } + else if (circular) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular, &pctx, n_threads); } + else if (use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_dma, &pctx, n_threads); } + else { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx, &pctx, n_threads); } + + return HTP_STATUS_OK; +} + From eb558f23cb6d2ea78d5e1401d62bebfa770446c9 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Mon, 18 May 2026 14:04:57 -0700 Subject: [PATCH 109/289] hexagon: add support for TRI op (llama/22822) * Hexagon: TRI HVX Kernel addition to ggml hexagon HTP ops and context * addressed PR review comments for TRI op * hexagon: clang format * hex-unary: remove merge conflict markers * hex-ggml: remove duplicate op cases (merge conflict) * hex-ggml: fix editor config errors --------- Co-authored-by: Todor Boinovski Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 20 +++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/unary-ops.c | 113 ++++++++++++++++++++++++- 5 files changed, 137 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index c24a2305e4c..2f75e97ac66 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2828,6 +2828,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { return false; } + if (dst->type != GGML_TYPE_F32) { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; } + + return true; + + GGML_UNUSED(sess); +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2869,6 +2884,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_TRI: return HTP_OP_TRI; case GGML_OP_PAD: return HTP_OP_PAD; case GGML_OP_UNARY: @@ -3430,6 +3446,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_solve_tri(sess, op); break; + case GGML_OP_TRI: + supp = ggml_hexagon_supported_tri(sess, op); + break; + case GGML_OP_PAD: supp = ggml_hexagon_supported_pad(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index e500ce46212..6fe3e6c7d85 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -107,6 +107,7 @@ int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); +int op_tri(struct htp_ops_context * octx); int op_pad(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 985ded6f299..676e948a439 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -86,6 +86,7 @@ enum htp_op_code { HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, HTP_OP_GATED_DELTA_NET, + HTP_OP_TRI, HTP_OP_PAD, HTP_OP_INVALID diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 85569f07289..12003c1fd8a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -601,6 +601,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); + case HTP_OP_TRI: + return op_tri(octx); + case HTP_OP_INVALID: break; diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index d4ae89ee6f0..1ce881353ec 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -17,7 +17,6 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "htp-ops.h" struct htp_unary_context { struct htp_ops_context * octx; @@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src, } } +static void tri_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + const uint32_t ir, + const struct htp_unary_context * uctx) { + + const int32_t ttype = op_params[0]; + const HVX_Vector zero = hvx_vec_splat_f32(0.0f); + const uint32_t nvec = row_elems / VLEN_FP32; + const uint32_t nloe = row_elems % VLEN_FP32; + + const uint32_t ne01 = uctx->octx->src[0]->ne[1]; + + for (uint32_t b = 0; b < num_rows; b++) { + const uint32_t abs_row = ir + b; + const uint32_t i01 = abs_row % ne01; + + const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size); + HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size); + + uint32_t boundary; + int keep_left; + switch (ttype) { + case 0: boundary = i01; keep_left = 0; break; // keep col >= row + case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row + case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row + case 3: boundary = i01; keep_left = 1; break; // keep col < row + default: boundary = 0; keep_left = 0; break; + } + if (boundary > row_elems) boundary = row_elems; + + // Full HVX vectors — each starts at a 128-byte aligned offset + for (uint32_t i = 0; i < nvec; i++) { + const uint32_t vec_start = i * VLEN_FP32; + const uint32_t vec_end = vec_start + VLEN_FP32; + if (keep_left) { + if (vec_end <= boundary) { + v_dst[i] = v_src[i]; + } else if (vec_start >= boundary) { + v_dst[i] = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero); + } + } else { + if (vec_end <= boundary) { + v_dst[i] = zero; + } else if (vec_start >= boundary) { + v_dst[i] = v_src[i]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]); + } + } + } + + // Tail elements (row_elems not a multiple of VLEN_FP32) + if (nloe > 0) { + const uint32_t vec_start = nvec * VLEN_FP32; + const uint32_t vec_end = vec_start + nloe; + HVX_Vector tail_val; + if (keep_left) { + if (vec_end <= boundary) { + tail_val = v_src[nvec]; + } else if (vec_start >= boundary) { + tail_val = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero); + } + } else { + if (vec_end <= boundary) { + tail_val = zero; + } else if (vec_start >= boundary) { + tail_val = v_src[nvec]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]); + } + } + hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val); + } + } +} + static void softplus_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -498,6 +586,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_L2_NORM: l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_TRI: + tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx); + break; default: break; } @@ -571,6 +662,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_L2_NORM: op_type = "l2norm-f32"; break; + case HTP_OP_TRI: + op_type = "tri-f32"; + break; + default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; @@ -640,6 +735,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return err; } +int op_tri(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src[0]->type) { + case HTP_TYPE_F32: + err = execute_op_unary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} + int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; From 3477fdb2e3c5b8828aa06da944c644b32bb7b6de Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 19 May 2026 09:42:36 +0300 Subject: [PATCH 110/289] rpc : keep last_graph_uid in the device context (llama/23273) With the introduction of MTP we can have multiple compute contexts for the same RPC device. In this case last_graph_uid is not updated properly when contexts are being switched. This patch fixes this by moving last_graph_uid to the device context, making sure it is always updated. closes: #23242 --- ggml/src/ggml-rpc/ggml-rpc.cpp | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1cb8f563d85..d3805772183 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() { return &guid; } +struct ggml_backend_rpc_device_context { + std::string endpoint; + uint32_t device; + std::string name; + std::string description; + uint64_t last_graph_uid; +}; + struct ggml_backend_rpc_buffer_type_context { std::string endpoint; uint32_t device; @@ -211,7 +219,6 @@ struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - uint64_t last_graph_uid; }; struct ggml_backend_rpc_buffer_context { @@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend); + ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; + bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->last_graph_uid = cgraph->uid; + rpc_dev_ctx->last_graph_uid = cgraph->uid; std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { /* .endpoint = */ endpoint, /* .device = */ device, /* .name = */ dev_name, - /* .last_graph_uid = */ 0, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } } -// device interface - -struct ggml_backend_rpc_device_context { - std::string endpoint; - uint32_t device; - std::string name; - std::string description; -}; - static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; @@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { std::string dev_name = "RPC" + std::to_string(dev_id); std::string dev_desc = std::string(endpoint); ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .device = */ ind, - /* .name = */ dev_name, - /* .description = */ dev_desc + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc, + /* .last_graph_uid = */ 0, }; ggml_backend_dev_t dev = new ggml_backend_device { From 28edd0cb36581e9b4c4eb0a253c8cc9c9b13f9db Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Mon, 18 May 2026 23:44:02 -0700 Subject: [PATCH 111/289] sycl: add GGML_SYCL_USE_ASYNC_MEM_OP env toggle (llama/22153) * sycl: add GGML_SYCL_USE_ASYNC_MEM_OP env toggle Signed-off-by: Chun Tao * Use async mem ops for correctness when SYCL graphs are explicitly on. Signed-off-by: Tao, Chun --------- Signed-off-by: Chun Tao Signed-off-by: Tao, Chun Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/ggml-sycl.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ebe7c5b351c..2ea47f7153a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -72,6 +72,7 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_use_async_mem_op_requested = 1; int g_ggml_sycl_enable_level_zero = 0; int g_ggml_sycl_enable_flash_attention = 1; @@ -304,6 +305,8 @@ static void ggml_check_sycl() try { GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); + g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); + GGML_LOG_INFO(" GGML_SYCL_USE_ASYNC_MEM_OP: %d\n", g_ggml_sycl_use_async_mem_op_requested); #ifdef SYCL_FLASH_ATTN GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention); @@ -319,11 +322,11 @@ static void ggml_check_sycl() try { fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); #endif */ - // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be - // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in - // other places. + // Async USM allocation/free is also useful outside the graph path: it avoids the host waits in the reorder + // staging path while preserving queue ordering semantics. Graph support still depends on the extension being + // available, but it no longer needs to control the non-graph fast path. #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC - g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph; + g_ggml_sycl_use_async_mem_op = g_ggml_sycl_use_async_mem_op_requested || !g_ggml_sycl_disable_graph; if (g_ggml_sycl_use_async_mem_op) { for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) { if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) { From 6090f39f36056a9eb673c7d77bb90a4126d71fa9 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 18 May 2026 23:45:41 -0700 Subject: [PATCH 112/289] ggml-webgpu : extend GDN for K>1 (llama/23299) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 ++ .../wgsl-shaders/gated_delta_net.wgsl | 24 +++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 78cb02be06d..921c12b41ac 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const uint32_t K = (uint32_t) src5->ne[1]; const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); @@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, (uint32_t) src0->ne[1], (uint32_t) (src2->ne[3] / src0->ne[3]), + K, scale_u32, }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index f9d98fda40b..d68520f8282 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -39,6 +39,7 @@ struct Params { neq1: u32, rq3: u32, + K: u32, scale: f32, }; @@ -62,11 +63,14 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_base = (seq_id * params.h + head_id) * state_size; + let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + let state_out_base = (seq_id * params.h + head_id) * state_size; + let state_size_per_snap = state_size * params.h * params.n_seqs; + let shift = i32(params.n_tokens) - i32(params.K); var state: array; for (var i = 0u; i < S_V; i++) { - state[i] = src_state[state_base + col * S_V + i]; + state[i] = src_state[state_in_base + col * S_V + i]; } var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; @@ -123,10 +127,22 @@ fn main( dst[attn_off + col] = attn_col * params.scale; attn_off += S_V * params.h; + if (params.K > 1u) { + let target_slot = i32(t) - shift; + if (target_slot >= 0 && target_slot < i32(params.K)) { + let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; + for (var i = 0u; i < S_V; i++) { + dst[slot_base + col * S_V + i] = state[i]; + } + } + } + workgroupBarrier(); } - for (var i = 0u; i < S_V; i++) { - dst[params.s_off + state_base + col * S_V + i] = state[i]; + if (params.K == 1u) { + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_out_base + col * S_V + i] = state[i]; + } } } From 459ff0707b51b0262536977d34e8ad2631063e68 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Tue, 19 May 2026 22:18:21 +0530 Subject: [PATCH 113/289] hexagon: enable support for NORM op (llama/23319) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 97 ++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 2f75e97ac66..ebeef3bdbaf 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2870,6 +2870,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; case GGML_OP_SCALE: return HTP_OP_SCALE; @@ -3338,10 +3339,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_NORM: case GGML_OP_L2_NORM: - supp = ggml_hexagon_supported_unary(sess, op); - break; - case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 676e948a439..9d905a30133 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -88,6 +88,7 @@ enum htp_op_code { HTP_OP_GATED_DELTA_NET, HTP_OP_TRI, HTP_OP_PAD, + HTP_OP_NORM, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 12003c1fd8a..8e54536f619 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -534,6 +534,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_ADD_ID: return op_binary(octx); + case HTP_OP_NORM: case HTP_OP_RMS_NORM: case HTP_OP_SCALE: case HTP_OP_SQR: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 1ce881353ec..40d2d60153a 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares and sum of values for full vectors + HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Reduce HVX sums + sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v)); + sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v); + HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v); + HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v)); + HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v); + HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v); + + // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); + HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v3); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v3); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + static void scale_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src, } } +static void norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } +} + static void sqr_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * // Process block in VTCM switch (htp_op) { + case HTP_OP_NORM: + norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const char * op_type = NULL; switch (octx->op) { + case HTP_OP_NORM: + op_type = "norm-f32"; + break; case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break; From aca63e76386b239cbd1692629ddece4e7e26f03b Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 20 May 2026 02:40:13 +0530 Subject: [PATCH 114/289] hexagon: add MROPE and IMROPE support in HTP rope op (llama/23317) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 115 +++++++++++++++++++++---- 2 files changed, 98 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ebeef3bdbaf..080fb7f47e3 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2661,7 +2661,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess int mode = op_params[2]; - if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + if (mode == GGML_ROPE_TYPE_VISION) { return false; } if (mode & 1) { diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 1d8b0796bc9..9901453e91e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -18,9 +18,11 @@ #include "htp-ops.h" #include "htp-ops.h" -// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h +// Redefined the rope type constants as we can't include ggml.h #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_TYPE_MROPE 8 +#define HTP_ROPE_TYPE_IMROPE 40 #define HTP_ROPE_SPAD_NROWS 16 #define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) @@ -82,6 +84,29 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return (1 - MIN(1, MAX(0, y))); } +// Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling. +static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims, + uint32_t i0, float ext_factor, float mscale, + float * cache) { + float theta_extrap = theta; + + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta_final = theta_interp; + float mscale_final = mscale; + + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + + cache[i0 + 0] = cosf(theta_final) * mscale_final; + cache[i0 + 1] = sinf(theta_final) * mscale_final; +} + static void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, @@ -96,26 +121,62 @@ static void rope_cache_init(const float theta_base, for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - float theta_extrap = theta / ff; - - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta_final = theta_interp; - float mscale_final = mscale; + theta *= theta_scale; + } +} - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; +// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). +// sections[4]: number of head dims assigned to each position component. +static void mrope_cache_init(const float pos_t, + const float pos_h, + const float pos_w, + const float pos_e, + const int32_t sections[4], + const bool is_imrope, + const float freq_scale, + const float * freq_factors, + float * corr_dims, + const uint32_t ne0, + const float ext_factor, + const float mscale, + float * cache, + const float theta_scale) { + const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + const int sec_w = sections[0] + sections[1]; + const int sec_e = sec_w + sections[2]; + + float theta_t = pos_t; + float theta_h = pos_h; + float theta_w = pos_w; + float theta_e = pos_e; - // Get n-d magnitude scaling corrected for interpolation - mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + const int sector = (i0 / 2) % sect_dims; + + float theta; + if (is_imrope) { + // Interleaved: sector mod 3 selects component + if (sector % 3 == 0 && sector < 3 * sections[0]) { theta = theta_t; } + else if (sector % 3 == 1 && sector < 3 * sections[1]) { theta = theta_h; } + else if (sector % 3 == 2 && sector < 3 * sections[2]) { theta = theta_w; } + else { theta = theta_e; } + } else { + // Contiguous sections + if (sector < sections[0]) { theta = theta_t; } + else if (sector < sec_w) { theta = theta_h; } + else if (sector < sec_e) { theta = theta_w; } + else { theta = theta_e; } } - cache[i0 + 0] = cosf(theta_final) * mscale_final; - cache[i0 + 1] = sinf(theta_final) * mscale_final; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - theta *= theta_scale; + theta_t *= theta_scale; + theta_h *= theta_scale; + theta_w *= theta_scale; + theta_e *= theta_scale; } } @@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { uint64_t tt = HAP_perf_get_qtimer_count(); const int32_t mode = rctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + // MROPE and IMROPE use NEOX-style pairing for the rotation + const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE); // VTCM setup uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); @@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { if (i2 != prev_i2) { prev_i2 = i2; - const int32_t p = pos[i2]; - rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); + const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0; + if (is_mrope) { + // src1 holds four position arrays stacked along ne0: + // pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3] + const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE); + mrope_cache_init( + (float) pos[i2], + (float) pos[i2 + ne2], + (float) pos[i2 + ne2 * 2], + (float) pos[i2 + ne2 * 3], + rctx->sections, is_imrope, + rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } else { + rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); From 37f17208c25aae5c93177599c05de6efc3b57446 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Tue, 19 May 2026 14:29:00 -0700 Subject: [PATCH 115/289] opencl: add MoE support for q4_k, q5_k, q6_k on Adreno (llama/23303) * opencl: add q4_k moe support * opencl: add q5_k moe support * opencl: add q6_k moe support * opencl: adjust format --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 6 + ggml/src/ggml-opencl/ggml-opencl.cpp | 948 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 385 +++++++ .../kernels/gemm_moe_q4_k_f32_ns.cl | 279 ++++++ .../kernels/gemm_moe_q5_k_f32_ns.cl | 284 ++++++ .../kernels/gemm_moe_q6_k_f32_ns.cl | 263 +++++ .../kernels/gemv_moe_q4_k_f32_ns.cl | 151 +++ .../kernels/gemv_moe_q5_k_f32_ns.cl | 156 +++ .../kernels/gemv_moe_q6_k_f32_ns.cl | 137 +++ 9 files changed, 2601 insertions(+), 8 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index c6aba608736..f75d089b574 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -110,6 +110,12 @@ set(GGML_OPENCL_KERNELS gemv_moe_q5_0_f32_ns gemm_moe_q5_1_f32_ns gemv_moe_q5_1_f32_ns + gemm_moe_q4_k_f32_ns + gemv_moe_q4_k_f32_ns + gemm_moe_q5_k_f32_ns + gemv_moe_q5_k_f32_ns + gemm_moe_q6_k_f32_ns + gemv_moe_q6_k_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0e511592d53..a3af8c2da41 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -558,6 +558,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; + cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; + cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; + cl_kernel kernel_convert_block_q6_k_trans4_ns, kernel_restore_block_q6_k_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -619,6 +622,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns; cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns; + cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns; + cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns; + cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -981,6 +987,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_k_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -3071,6 +3083,108 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -4148,6 +4262,8 @@ struct ggml_tensor_extra_cl_iq4_nl { struct ggml_tensor_extra_cl_q4_K { // Quantized values cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Scales for each super block. cl_mem s = nullptr; // Scales @@ -4176,12 +4292,18 @@ struct ggml_tensor_extra_cl_q4_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } } }; struct ggml_tensor_extra_cl_q5_K { // Lower 4 bits of quantized weights. cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Upper 1 bit of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4222,6 +4344,10 @@ struct ggml_tensor_extra_cl_q5_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } size_q = 0; size_qh = 0; @@ -4234,6 +4360,8 @@ struct ggml_tensor_extra_cl_q5_K { struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; + // Lower 4 bits as image1d_buffer_t + cl_mem ql_img = nullptr; // Upper 2 bits of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4267,6 +4395,10 @@ struct ggml_tensor_extra_cl_q6_K { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + if (ql_img != nullptr) { + CL_CHECK(clReleaseMemObject(ql_img)); + ql_img = nullptr; + } size_ql = 0; size_qh = 0; @@ -4700,7 +4832,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te // the quantizations here currently do not - they are only supported by Adreno with certain shapes if (op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || - op->src[0]->type == GGML_TYPE_Q5_1) { + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (op->src[1]->type == GGML_TYPE_F32) { return use_adreno_moe_kernels(backend_ctx, op->src[0]) @@ -6047,14 +6182,57 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - #endif +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -6157,14 +6335,58 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; - #endif +#endif cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -6232,6 +6454,79 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno MoE Q6_K kernel needs special transposed layout + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + size_t moe_size_ql = (size_t)(ggml_nelements(tensor) / 8) * sizeof(uint32_t); // 4 bits per element + size_t moe_size_qh = (size_t)(ggml_nelements(tensor) / 16) * sizeof(uint32_t); // 2 bits per element + size_t moe_size_s = size_s; + size_t moe_size_d = size_d; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = moe_size_ql; + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + moe_size_ql, backend_ctx->alignment); + region.size = moe_size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + moe_size_qh, backend_ctx->alignment); + region.size = moe_size_s; + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for d + region.origin = align_to(previous_origin + moe_size_s, backend_ctx->alignment); + region.size = moe_size_d; + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for ql + cl_image_format img_format_ql = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_ql = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->ql } + }; + extra->ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_ql, &img_desc_ql, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // Subbuffer for ql region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_ql; @@ -6825,6 +7120,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6901,6 +7230,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6974,7 +7337,44 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; @@ -13733,6 +14133,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13741,6 +14144,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, (void)extra0_q4_1; (void)extra0_q5_0; (void)extra0_q5_1; + (void)extra0_q4_K; + (void)extra0_q5_K; + (void)extra0_q6_K; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -14612,6 +15018,532 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q6_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q6_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q6_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 8f06d570587..312366984b6 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -664,6 +664,391 @@ kernel void kernel_restore_block_q5_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +kernel void kernel_convert_block_q4_k_trans4_ns( + __global struct block_q4_K * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_q[base + (p * 4 + 0) * ne01] = v.x; + dst_q[base + (p * 4 + 1) * ne01] = v.y; + dst_q[base + (p * 4 + 2) * ne01] = v.z; + dst_q[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_k_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q4_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_q[base + (p * 4 + 0) * ne01]; + qv[p].y = src_q[base + (p * 4 + 1) * ne01]; + qv[p].z = src_q[base + (p * 4 + 2) * ne01]; + qv[p].w = src_q[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q5_k_trans4_ns( + __global struct block_q5_K * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + for (int k = 0; k < 8; k++) { + uchar b0 = 0, b1 = 0, b2 = 0, b3 = 0; + for (int bit = 0; bit < 8; bit++) { + b0 |= (uchar)(((b->qh[bit] >> k) & 1) << bit); + b1 |= (uchar)(((b->qh[8 + bit] >> k) & 1) << bit); + b2 |= (uchar)(((b->qh[16 + bit] >> k) & 1) << bit); + b3 |= (uchar)(((b->qh[24 + bit] >> k) & 1) << bit); + } + uint packed = (uint)b0 | ((uint)b1 << 8) | ((uint)b2 << 16) | ((uint)b3 << 24); + dst_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01] = packed; + } + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_qs[base + (p * 4 + 0) * ne01] = v.x; + dst_qs[base + (p * 4 + 1) * ne01] = v.y; + dst_qs[base + (p * 4 + 2) * ne01] = v.z; + dst_qs[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_k_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q5_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + for (int j = 0; j < 32; j++) b->qh[j] = 0; + for (int k = 0; k < 8; k++) { + uint packed = src_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01]; + uchar b0 = (uchar)(packed & 0xFF); + uchar b1 = (uchar)((packed >> 8) & 0xFF); + uchar b2 = (uchar)((packed >> 16) & 0xFF); + uchar b3 = (uchar)((packed >> 24) & 0xFF); + for (int bit = 0; bit < 8; bit++) { + b->qh[bit] |= (uchar)(((b0 >> bit) & 1) << k); + b->qh[8 + bit] |= (uchar)(((b1 >> bit) & 1) << k); + b->qh[16 + bit] |= (uchar)(((b2 >> bit) & 1) << k); + b->qh[24 + bit] |= (uchar)(((b3 >> bit) & 1) << k); + } + } + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_qs[base + (p * 4 + 0) * ne01]; + qv[p].y = src_qs[base + (p * 4 + 1) * ne01]; + qv[p].z = src_qs[base + (p * 4 + 2) * ne01]; + qv[p].w = src_qs[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q6_k_trans4_ns( + __global struct block_q6_K * src0, + __global uint * dst_ql, + __global uint * dst_qh, + __global half * dst_d, + __global char * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = src0 + src_blk_offset; + + dst_d[dst_blk_offset] = b->d; + + uint4 qlv[8]; + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->ql[i*64 + 2*j]; + uchar x1 = b->ql[i*64 + 2*j + 1]; + uchar x2 = b->ql[i*64 + 32 + 2*j]; + uchar x3 = b->ql[i*64 + 32 + 2*j + 1]; + qlv_bytes[i*64 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qlv_bytes[i*64 + j + 16] = convert_uchar(x2 & mask_0F) | convert_uchar((x3 & mask_0F) << 4); + qlv_bytes[i*64 + j + 32] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + qlv_bytes[i*64 + j + 48] = convert_uchar((x2 & mask_F0) >> 4) | convert_uchar(x3 & mask_F0); + } + } + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qlv[p]; + dst_ql[ql_base + (p * 4 + 0) * ne01] = v.x; + dst_ql[ql_base + (p * 4 + 1) * ne01] = v.y; + dst_ql[ql_base + (p * 4 + 2) * ne01] = v.z; + dst_ql[ql_base + (p * 4 + 3) * ne01] = v.w; + } + + uint qhv[16] = {0}; + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + uchar h = b->qh[n*32 + l]; + int u = l / 16; + int bit_pos = (l % 16) * 2; + qhv[(n*4 + 0)*2 + u] |= ((uint)((h >> 0) & 0x03)) << bit_pos; + qhv[(n*4 + 1)*2 + u] |= ((uint)((h >> 2) & 0x03)) << bit_pos; + qhv[(n*4 + 2)*2 + u] |= ((uint)((h >> 4) & 0x03)) << bit_pos; + qhv[(n*4 + 3)*2 + u] |= ((uint)((h >> 6) & 0x03)) << bit_pos; + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + + for (int p = 0; p < 16; ++p) { + dst_qh[qh_base + p * ne01] = qhv[p]; + } + + __global char * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + #pragma unroll + for (int i = 0; i < 16; ++i) { + s_dst[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_k_trans4_ns( + __global uint * src_ql, + __global uint * src_qh, + __global half * src_d, + __global char * src_s, + __global struct block_q6_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + uint4 qlv[8]; + for (int p = 0; p < 8; ++p) { + qlv[p].x = src_ql[ql_base + (p * 4 + 0) * ne01]; + qlv[p].y = src_ql[ql_base + (p * 4 + 1) * ne01]; + qlv[p].z = src_ql[ql_base + (p * 4 + 2) * ne01]; + qlv[p].w = src_ql[ql_base + (p * 4 + 3) * ne01]; + } + + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo_02 = qlv_bytes[i*64 + j]; + uchar lo_13 = qlv_bytes[i*64 + j + 16]; + uchar hi_02 = qlv_bytes[i*64 + j + 32]; + uchar hi_13 = qlv_bytes[i*64 + j + 48]; + b->ql[i*64 + 2*j] = convert_uchar((lo_02 & mask_0F) | ((hi_02 & mask_0F) << 4)); + b->ql[i*64 + 2*j + 1] = convert_uchar(((lo_02 & mask_F0) >> 4) | (hi_02 & mask_F0)); + b->ql[i*64 + 32 + 2*j] = convert_uchar((lo_13 & mask_0F) | ((hi_13 & mask_0F) << 4)); + b->ql[i*64 + 32 + 2*j + 1] = convert_uchar(((lo_13 & mask_F0) >> 4) | (hi_13 & mask_F0)); + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + uint qhv[16]; + for (int p = 0; p < 16; ++p) { + qhv[p] = src_qh[qh_base + p * ne01]; + } + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + int u = l / 16; + int bit_pos = (l % 16) * 2; + uchar v0 = (uchar)((qhv[(n*4 + 0)*2 + u] >> bit_pos) & 0x03); + uchar v1 = (uchar)((qhv[(n*4 + 1)*2 + u] >> bit_pos) & 0x03); + uchar v2 = (uchar)((qhv[(n*4 + 2)*2 + u] >> bit_pos) & 0x03); + uchar v3 = (uchar)((qhv[(n*4 + 3)*2 + u] >> bit_pos) & 0x03); + b->qh[n*32 + l] = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6); + } + } + + __global char * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + for (int i = 0; i < 16; ++i) { + b->scales[i] = s_src[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..9d24aff6a20 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -0,0 +1,279 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q4_k(q4, a_f16, scale, minv) \ + a_f16.s0 = (half)((float)(q4.s0 & 0x000F) * scale - minv); \ + a_f16.s1 = (half)((float)((q4.s0 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s2 = (half)((float)((q4.s0 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s3 = (half)((float)((q4.s0 & 0xF000) >> 12) * scale - minv); \ + a_f16.s4 = (half)((float)(q4.s1 & 0x000F) * scale - minv); \ + a_f16.s5 = (half)((float)((q4.s1 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s6 = (half)((float)((q4.s1 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s7 = (half)((float)((q4.s1 & 0xF000) >> 12) * scale - minv); \ + a_f16.s8 = (half)((float)(q4.s2 & 0x000F) * scale - minv); \ + a_f16.s9 = (half)((float)((q4.s2 & 0x00F0) >> 4) * scale - minv); \ + a_f16.sa = (half)((float)((q4.s2 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sb = (half)((float)((q4.s2 & 0xF000) >> 12) * scale - minv); \ + a_f16.sc = (half)((float)(q4.s3 & 0x000F) * scale - minv); \ + a_f16.sd = (half)((float)((q4.s3 & 0x00F0) >> 4) * scale - minv); \ + a_f16.se = (half)((float)((q4.s3 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sf = (half)((float)((q4.s3 & 0xF000) >> 12) * scale - minv); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q4_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half (next 16 elements, same sub-block scale) + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..808a0c7db6a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -0,0 +1,284 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q5_k(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((float)(( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((float)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.s2 = (half)((float)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.s3 = (half)((float)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.s4 = (half)((float)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.s5 = (half)((float)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.s6 = (half)((float)(((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((float)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m)); \ + a_f16.s8 = (half)((float)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m)); \ + a_f16.s9 = (half)((float)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.sa = (half)((float)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.sb = (half)((float)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.sc = (half)((float)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.sd = (half)((float)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.se = (half)((float)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m)); \ + a_f16.sf = (half)((float)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m)); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q5_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uint * src0_qh, + __global uchar * src0_s, + __global half * src0_d, + __global half * src0_dm, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // qh is stored at sub-block granularity + uint qh_offset = row + sub * ne01 + expert_id * num_superblocks * 8 * ne01 + get_global_id(0); + uchar4 qhx32 = as_uchar4(src0_qh[qh_offset]); + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_k(as_ushort4(q4x16), qhx32.lo, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q5_k(as_ushort4(q4x16), qhx32.hi, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..a040335adfa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -0,0 +1,263 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 + +#define dequantize_q6_k(qs16, qh16, a_f16, scale) \ + a_f16.s0 = (half)(((float)(( qs16.s0 & 0x000F) | ((uint)(( qh16 ) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s1 = (half)(((float)((( qs16.s0 >> 4) & 0x000F) | ((uint)(( qh16 >> 2) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s2 = (half)(((float)((( qs16.s0 >> 8) & 0x000F) | ((uint)(( qh16 >> 4) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s3 = (half)(((float)((( qs16.s0 >>12) & 0x000F) | ((uint)(( qh16 >> 6) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s4 = (half)(((float)(( qs16.s1 & 0x000F) | ((uint)(( qh16 >> 8) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s5 = (half)(((float)((( qs16.s1 >> 4) & 0x000F) | ((uint)(( qh16 >> 10) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s6 = (half)(((float)((( qs16.s1 >> 8) & 0x000F) | ((uint)(( qh16 >> 12) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s7 = (half)(((float)((( qs16.s1 >>12) & 0x000F) | ((uint)(( qh16 >> 14) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s8 = (half)(((float)(( qs16.s2 & 0x000F) | ((uint)(( qh16 >> 16) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s9 = (half)(((float)((( qs16.s2 >> 4) & 0x000F) | ((uint)(( qh16 >> 18) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sa = (half)(((float)((( qs16.s2 >> 8) & 0x000F) | ((uint)(( qh16 >> 20) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sb = (half)(((float)((( qs16.s2 >>12) & 0x000F) | ((uint)(( qh16 >> 22) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sc = (half)(((float)(( qs16.s3 & 0x000F) | ((uint)(( qh16 >> 24) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sd = (half)(((float)((( qs16.s3 >> 4) & 0x000F) | ((uint)(( qh16 >> 26) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.se = (half)(((float)((( qs16.s3 >> 8) & 0x000F) | ((uint)(( qh16 >> 28) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sf = (half)(((float)((( qs16.s3 >>12) & 0x000F) | ((uint)(( qh16 >> 30) & 0x3) << 4)) - 32.f) * scale); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q6_k_f32_ns( + __read_only image1d_buffer_t src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * 16; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; // 32-element group index + uint sb = sub / 8; // super-block index + uint j = sub % 8; // group within super-block + + // Load d for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + + // Load sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + uint qh_base = row + (sub * 2) * ne01 + expert_id * (num_superblocks * 16) * ne01 + get_global_id(0); + uint qh_first16 = src0_qh[qh_base]; + uint qh_second16 = src0_qh[qh_base + ne01]; + + // First half (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 ql nibbles (2 uints) from image + uint2 q4x16; + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantize first 16 elements (scale0) + dequantize_q6_k(as_ushort4(q4x16), qh_first16, reg_a, scale0); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q6_k(as_ushort4(q4x16), qh_second16, reg_a, scale1); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..13d79f2526f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -0,0 +1,151 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q4_k_to_fp32_packed8(ushort2 q4x8, float scale, float minv) { + float8 fp32x8; + fp32x8.s0 = (q4x8.s0 & 0x000F) * scale - minv; + fp32x8.s1 = ((q4x8.s0 & 0x00F0) >> 4) * scale - minv; + fp32x8.s2 = ((q4x8.s0 & 0x0F00) >> 8) * scale - minv; + fp32x8.s3 = ((q4x8.s0 & 0xF000) >> 12) * scale - minv; + fp32x8.s4 = (q4x8.s1 & 0x000F) * scale - minv; + fp32x8.s5 = ((q4x8.s1 & 0x00F0) >> 4) * scale - minv; + fp32x8.s6 = ((q4x8.s1 & 0x0F00) >> 8) * scale - minv; + fp32x8.s7 = ((q4x8.s1 & 0xF000) >> 12) * scale - minv; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_k_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s0), scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s1), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s2), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s3), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..f128d44340a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -0,0 +1,156 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q5_k_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_k_f32_ns( + __global uint * src0_q, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // sub_block index = sb * 8 + j + uint expert_qh_offset = expert_id * num_superblocks * 8 * ne01; + uchar4 regQh = as_uchar4(src0_qh[expert_qh_offset + (sb * 8 + j) * ne01 + i01]); + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..526e609dc3a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -0,0 +1,137 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q6_k_to_fp32_packed8(ushort2 ql8, ushort qh8, float d_scale) { + float8 fp32x8; + fp32x8.s0 = ((float)(( ql8.s0 & 0x000F) | ((uint)((qh8 ) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s1 = ((float)((( ql8.s0 >> 4) & 0x000F) | ((uint)((qh8 >> 2) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s2 = ((float)((( ql8.s0 >> 8) & 0x000F) | ((uint)((qh8 >> 4) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s3 = ((float)((( ql8.s0 >> 12)& 0x000F) | ((uint)((qh8 >> 6) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s4 = ((float)(( ql8.s1 & 0x000F) | ((uint)((qh8 >> 8) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s5 = ((float)((( ql8.s1 >> 4) & 0x000F) | ((uint)((qh8 >>10) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s6 = ((float)((( ql8.s1 >> 8) & 0x000F) | ((uint)((qh8 >>12) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s7 = ((float)((( ql8.s1 >> 12)& 0x000F) | ((uint)((qh8 >>14) & 0x3) << 4)) - 32.f) * d_scale; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q6_k_f32_ns( + __global uint * src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; // 8 sub-blocks of 32 per super-block + int scales_per_row = num_superblocks * 16; + + // Expert offsets in the transposed noshuffle layout + uint expert_ql_offset = expert_id * (ne00 / 8) * ne01; // 32 uints per super-block + uint expert_qh_offset = expert_id * (ne00 / 16) * ne01; // 16 uints per super-block + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; // super-block index + uint j = ib % 8; // 32-element group within super-block + + // Load d for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + + // Load 2 sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + // Load 4 uints of ql + uint ql_base = expert_ql_offset + (ib * 4) * ne01 + i01; + uint4 regQL; + regQL.s0 = src0_ql[ql_base]; + regQL.s1 = src0_ql[ql_base + ne01]; + regQL.s2 = src0_ql[ql_base + ne01 * 2]; + regQL.s3 = src0_ql[ql_base + ne01 * 3]; + + // Load 2 uints of qh + uint qh_base = expert_qh_offset + (ib * 2) * ne01 + i01; + uint2 regQH; + regQH.s0 = src0_qh[qh_base]; + regQH.s1 = src0_qh[qh_base + ne01]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s0), (ushort)(regQH.s0 & 0xFFFF), scale0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s1), (ushort)(regQH.s0 >> 16), scale0); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s2), (ushort)(regQH.s1 & 0xFFFF), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s3), (ushort)(regQH.s1 >> 16), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} From 0a0a34287e84641c34679c1a38fed299a56e9d4b Mon Sep 17 00:00:00 2001 From: ravel7524 <58877666+ravel7524@users.noreply.github.com> Date: Wed, 20 May 2026 03:52:21 +0200 Subject: [PATCH 116/289] ggml-cuda: tune RDNA3 Q6_K MMVQ nwarps (llama/23349) --- ggml/src/ggml-cuda/mmvq.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index da48f313a38..73a0991e206 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: + return 8; case GGML_TYPE_Q6_K: + return 2; case GGML_TYPE_IQ4_NL: return 8; default: From c58fc465dfed99c3e51a32e27a76d82f19f6481c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 20 May 2026 09:42:00 +0300 Subject: [PATCH 117/289] metal : optimize pad + cpy (llama/23354) * metal : optimize pad * metal : optinmize cpy * cont : better row packing in threadgroup --- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 17 ++- ggml/src/ggml-metal/ggml-metal.metal | 128 ++++++++++++---------- 3 files changed, 90 insertions(+), 63 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e288a27f992..ba006d9b31a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l char base[256]; char name[256]; - snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + // note: this is slower + //const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0; + const bool is_c4 = false; + + snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a114391c2e8..8506000b6c0 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -816,9 +816,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const int nth = MIN(args.ne00, nth_max); - const int nk0 = (args.ne00 + nth - 1)/nth; ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); @@ -1863,7 +1861,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nk0 = ne00/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne01, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1874,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } @@ -4039,14 +4037,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); - const int nth = std::min(1024, ne0); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const int nth = MIN(args.ne0, nth_max); + const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel! ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f6ffb2b3a1c..4cf9dbea946 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - if (K > 1u) { + if (K > 1) { const int target_slot = (int)t - shift; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; @@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl( } } - if (K == 1u) { + if (K == 1) { device float * dst_state = (device float *) (dst) + attn_size + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; @@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32( for (int64_t sx = x_min; sx < x_max; ++sx) { const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); const float w = wx * wy; - const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); sum += (*src_ptr) * w; wsum += w; } @@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32( const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; - const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); sum += (*src_ptr) * wx * wy; } } @@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32( } } -kernel void kernel_pad_f32( +template +kernel void kernel_pad_impl( constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t k0 = tgpig.x/args.ne1; + const int32_t i1 = tgpig.x - k0*args.ne1; - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int32_t i03 = i3; + const int32_t i02 = i2; + const int32_t i01 = i1; - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; - - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); - device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - if (i0 < args.ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } + for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) { + const int32_t i0 = k0*1024 + tpitg.x + l0; + if (i0 >= args.ne0) { + break; } - return; - } - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } } } +typedef decltype(kernel_pad_impl) kernel_pad_t; + +template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl; +template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl; + +// TODO: this is slow - optimize kernel void kernel_pad_reflect_1d_f32( constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, @@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; break; @@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; + const int32_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); quantize_func(src, dst_data[i00]); @@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; From 3fa19558f223461fa30164384a65f592d63577ca Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Wed, 20 May 2026 13:59:02 +0200 Subject: [PATCH 118/289] Programmatic Dependent Launch (PDL) for more performance on newer NVIDIA GPUs (Hopper+) (llama/22522) * Adds initial PDL setup. * Adds PDL barriers based on simple heuristic: place "sync" before first input pointer access, and "launch" after last write, e.g. to tensors like dst. * Further optimization pass of the first half of kernels * Optimized PDL barriers for the second batch of kernels * Further refinements after rebase. * Moves pdl logic to separate function, removes some whitespace * Strips post-hoc PDL logic * Adds stream capture PDL setup. Enrolls quantize_q8_1 to leverage pdl to overlap execution with previous kernels * Enrolls mul_mat_vec_q, rms_norm_f32 and k_bin_bcast (partly) into PDL * Enrolls mmvf, rope, set-rows and topk kernels for gpt-oss into PDL * Introduce ggml_cuda_kernel_launch, to abstract away cudaLaunchKernelEx, to enable hip/musa compatibility * Enrolls cpy_scalar_contiguous, k_get_rows_float and rms_norm_f32 * Enrolls flash_attn_combine_results * Fix: Drops needless and broken check of CUDA arch for PDL. PDL either works or is without effect. * Enrolls flash-attention kernels to pdl * Fix: inlines ggml_cuda_kernel_launch, and uses perfect forwarding for kernels args. This fixes PDL. * Perf: Enrolls k_bin_bcast variadic template invocation into PDL, via and template alias and template expansion * Enrolls all remaining kernels for qwen3-coder-next into PDL * Remove all PDL LC calls to create a baseline * Added LC according to internal guidance and tested kernel performance. * Enrols missing qwen3-5 kernels passively into PDL. * Kernel optimizations (LC signals) for qwen3.5 * Enrolls ssm-scan kernels into PDL * Adds GGML_CUDA_PDL command line option to toggle PDL. * Fix: Ada and lower compilation by guarding PDL calls correctly * Cleanup: Removes commented out GGML_CUDA_PDL_LC * Cleanup: Removes experimental comments * Adds 90-virtual to build script so that Hopper GPUs can leverage PDL. * Adds stricter checks to enable PDL, adds env-check to disable it, and removes now superfluous compile option to enable PDL. * Fix: Correct PDL en/disablement based on device-side arch check. Host side check is UB. Required moving from macros to inlined functions * Fix: default-disable PDL. Enable by setting GGML_CUDA_ENABLE_PDL=1 * Enable PDL by default for Hopper+ devices * Enrolls softcap_f32 and two flash_attn kernels into PDL. * Improves flash attn PDL barrier placement * Fix: Perf regression on ada; excludes ada and below from PDL launches * Improves some sync barrier placements * Drops superfluous constructor * Adds #endif guard comments * Reverts experimental change to top-k-moe.cu, which moved expensive allocations in front of the PDL barrier. It did not have a meaningful impact. * Exchanges GGML_CUDA_DISABLE_PDL with GGML_CUDA_PDL. IFF GGML_CUDA_PDL=0 PDL is disabled * Revert "Drops superfluous constructor". Adds const to remaining arguments This reverts commit 12b1d250da0089ae02a9bb71bbb3fd6d70f6f2f1. * Cleanup: Removes and fixes some comments and whitespace * Clarifies comment of sync-barrier position * Relocates and refactors PDL launch functions and accessories * Adds error checking to the regular kernel launch path * Drops "auto" in favor of "ggml_cuda_kernel_params" * Adds "const" to ggml_cuda_kernel_launch_params * [Whitespace] Adds final newline to common.cuh to make editorconfig CI job happy --- ggml/src/ggml-cuda/CMakeLists.txt | 3 +- ggml/src/ggml-cuda/binbcast.cu | 32 +++++----- ggml/src/ggml-cuda/common.cuh | 86 +++++++++++++++++++++++++++ ggml/src/ggml-cuda/concat.cu | 5 +- ggml/src/ggml-cuda/cpy.cu | 20 +++++-- ggml/src/ggml-cuda/fattn-common.cuh | 28 +++++---- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 1 + ggml/src/ggml-cuda/fattn-tile.cuh | 2 + ggml/src/ggml-cuda/fattn-vec.cuh | 3 + ggml/src/ggml-cuda/fattn-wmma-f16.cu | 1 + ggml/src/ggml-cuda/gated_delta_net.cu | 11 ++-- ggml/src/ggml-cuda/getrows.cu | 7 ++- ggml/src/ggml-cuda/mean.cu | 6 +- ggml/src/ggml-cuda/mmvf.cu | 12 ++-- ggml/src/ggml-cuda/mmvq.cu | 11 ++-- ggml/src/ggml-cuda/norm.cu | 46 ++++++++++---- ggml/src/ggml-cuda/quantize.cu | 7 ++- ggml/src/ggml-cuda/reduce_rows.cuh | 2 + ggml/src/ggml-cuda/rope.cu | 15 +++-- ggml/src/ggml-cuda/scale.cu | 5 +- ggml/src/ggml-cuda/set-rows.cu | 11 +++- ggml/src/ggml-cuda/softcap.cu | 5 +- ggml/src/ggml-cuda/ssm-conv.cu | 8 ++- ggml/src/ggml-cuda/ssm-scan.cu | 27 +++++---- ggml/src/ggml-cuda/sumrows.cu | 12 ++-- ggml/src/ggml-cuda/topk-moe.cu | 47 ++++++++------- ggml/src/ggml-cuda/unary.cu | 10 +++- 27 files changed, 310 insertions(+), 113 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index b54d4a6b107..d3953eee962 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND) # 80 == Ampere, asynchronous data loading, faster tensor core instructions # 86 == RTX 3000, needs CUDA v11.1 # 89 == RTX 4000, needs CUDA v11.8 + # 90 == Hopper H100/200, needs CUDA v11.8 # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores # # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run @@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND) list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual) endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index adb4d5f0cb9..c25f42b32bb 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -2,6 +2,9 @@ #include #include +template +using type_for_index = T; + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const int s12, const int s13, src1_ptrs... src1s) { + ggml_cuda_pdl_lc(); const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); @@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; + ggml_cuda_pdl_sync(); for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); @@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); + ggml_cuda_pdl_sync(); float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); @@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast_unravel<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast_unravel...>, launch_params, src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, - ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, - s00, s01, s02, s03, - s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast...>, launch_params, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, - s00 ,s01, s02, s03, - s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast<<>>( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, s00, s01, s02, s03, - s10, s11, s12, s13); + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } } @@ -333,6 +328,7 @@ static __global__ void k_repeat_back( } T sum = 0; + ggml_cuda_pdl_sync(); for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10817505d9f..9c73fe7e6fa 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -5,6 +5,7 @@ #include "ggml-cuda.h" #include +#include #include #if defined(GGML_USE_HIP) @@ -27,6 +28,7 @@ #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -50,6 +52,7 @@ #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_HOPPER 900 // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms #define GGML_CUDA_CC_BLACKWELL 1200 @@ -107,6 +110,24 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA. +// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +# define GGML_CUDA_USE_PDL +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 + +static __device__ __forceinline__ void ggml_cuda_pdl_sync() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaGridDependencySynchronize(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + +static __device__ __forceinline__ void ggml_cuda_pdl_lc() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaTriggerProgrammaticLaunchCompletion(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + #ifdef __CUDA_ARCH_LIST__ constexpr bool ggml_cuda_has_arch_impl(int) { return false; @@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) + #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) static const char * cublas_get_error_str(const cublasStatus_t err) { return cublasGetStatusString(err); @@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device { const void * gate_bias = nullptr; ggml_glu_op glu_op; }; + +struct ggml_cuda_kernel_launch_params { + dim3 block_nums; + dim3 block_dims; + size_t shmem; + cudaStream_t stream; + + // size_t shmem + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {} + + // Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings. + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {} +}; + +#if defined(GGML_CUDA_USE_PDL) +struct ggml_cuda_pdl_config { + cudaLaunchAttribute attr; + cudaLaunchConfig_t cfg; + + ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) { + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = 1; + + cfg = {}; + cfg.gridDim = params.block_nums; + cfg.blockDim = params.block_dims; + cfg.dynamicSmemBytes = params.shmem; + cfg.stream = params.stream; + cfg.attrs = &attr; + cfg.numAttrs = 1; + } + + // Delete due to &attr + ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; + +}; +#endif //defined(GGML_CUDA_USE_PDL) + + +template +static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { +#if defined(GGML_CUDA_USE_PDL) + + static const bool env_pdl_enabled = []() { + const char * env = getenv("GGML_CUDA_PDL"); + return env == nullptr || std::atoi(env) != 0; + }(); + + if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) { + auto pdl_cfg = ggml_cuda_pdl_config(launch_params); + + CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward(args)... )); + return; + } +#endif //defined(GGML_CUDA_USE_PDL) + + kernel<<>>(std::forward(args)... ); + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index 102f944f924..adba4d522a4 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont const int64_t n = ne0 * ne1 * ne2; + ggml_cuda_pdl_sync(); for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { if constexpr (dim == 0) { const int64_t row = i / ne0; @@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x, const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { - concat_f32_cont<0> - <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index d208acf2d5f..121472ec228 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13) { + ggml_cuda_pdl_lc(); const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + ggml_cuda_pdl_sync(); cpy_1(cx + x_offset, cdst + dst_offset); } @@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; int cur_tile_buf = 0; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const const src_t * x = (const src_t *) cx; dst_t * dst = (dst_t *) cdst; + ggml_cuda_pdl_sync(); dst[i] = ggml_cuda_cast(x[i]); } @@ -192,8 +198,8 @@ cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar_contiguous<<>> - (cx, cdst, ne); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_contiguous, launch_params, cx, cdst, ne); } template @@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda( GGML_ASSERT(grid_z < USHRT_MAX); dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_scalar_transpose<<>> - (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_transpose, launch_params, + cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar>, launch_params, + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb5238946..debcb6e5447 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -636,6 +636,7 @@ static __global__ void flash_attn_mask_to_KV_max( if (tid < WARP_SIZE) { buf_iw[tid] = 1; } + ggml_cuda_pdl_sync(); __syncthreads(); int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; @@ -687,6 +688,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j_z, const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; + ggml_cuda_pdl_lc(); const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -718,6 +720,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + ggml_cuda_pdl_sync(); // Load the partial result that needs a fixup float dst_val = *dst; float max_val; @@ -809,6 +812,7 @@ static __global__ void flash_attn_stream_k_fixup_general( float dst_val = 0.0f; float max_val = 0.0f; float rowsum = 0.0f; + ggml_cuda_pdl_sync(); { dst_val = *dst; @@ -867,6 +871,7 @@ static __global__ void flash_attn_combine_results( const float2 * __restrict__ VKQ_meta, float * __restrict__ dst, const int parallel_blocks) { + ggml_cuda_pdl_lc(); // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -890,6 +895,7 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; + ggml_cuda_pdl_sync(); for (int i = tid; i < 2*parallel_blocks; i += D) { ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } @@ -1146,7 +1152,9 @@ void launch_fattn( const uint3 ne01 = init_fastdiv_values(Q->ne[1]); GGML_ASSERT(block_dim.x % warp_size == 0); - fattn_kernel<<>>( + + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, @@ -1176,9 +1184,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; - flash_attn_stream_k_fixup_uniform - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_uniform, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, gqa_ratio, bpt, fd0, fd1, fd2); } else if (ntiles_dst % blocks_num.x != 0) { @@ -1193,9 +1201,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup_general - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_general, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], gqa_ratio, total_work, fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } @@ -1204,9 +1212,9 @@ void launch_fattn( const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream); + ggml_cuda_kernel_launch(flash_attn_combine_results, launch_params, + dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index a25e912c4d2..4871b90df86 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1724,6 +1724,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_sync(); // TODO optimize placement #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 7b0a5e5cf49..fac76f13593 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -894,6 +894,8 @@ static __global__ void flash_attn_tile( } float KQ_sum[cpw] = {0.0f}; + ggml_cuda_pdl_sync(); + // Load Q data, convert to FP16 if fast: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f0bd42a5761..b0a6cf67f1a 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -136,6 +137,8 @@ static __global__ void flash_attn_ext_vec( #endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + + ggml_cuda_pdl_sync(); if constexpr (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index f19defbff93..4b6f6501094 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -86,6 +86,7 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + ggml_cuda_pdl_sync(); const int sequence = blockIdx.z / ne02; const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index b4c9845e7a7..018d5d37d47 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,4 +1,5 @@ #include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) @@ -53,6 +54,7 @@ gated_delta_net_cuda(const float * q, float s_shard[rows_per_lane]; // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + ggml_cuda_pdl_sync(); #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; @@ -189,28 +191,29 @@ static void launch_gated_delta_net( int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 36b840e8148..457b695eb2a 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -11,6 +11,7 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -48,6 +49,8 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_lc(); + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -83,6 +86,7 @@ static __global__ void k_get_rows_back_float( float sum = 0.0f; + ggml_cuda_pdl_sync(); for (int64_t i = 0; i < nrows_grad; ++i) { if (rows[i] != dst_row) { continue; @@ -156,7 +160,8 @@ static void get_rows_cuda_float( GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); const uint3 ne12_fdv = init_fastdiv_values(ne12); - k_get_rows_float<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{block_nums, block_dims, 0, stream}; + ggml_cuda_kernel_launch(k_get_rows_float, launch_params, src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 49af5389957..a8f6046e46d 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -67,9 +67,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132 if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index d9147202429..09d95f309b4 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -21,6 +21,7 @@ static __global__ void mul_mat_vec_f( int channel_y; int sample_dst; + ggml_cuda_pdl_sync(); if constexpr (is_multi_token_id) { // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case token_idx = blockIdx.z; @@ -298,6 +299,7 @@ static __global__ void mul_mat_vec_f( static_assert(std::is_same_v, "unsupported type"); } + ggml_cuda_pdl_lc(); #pragma unroll for (int j = 0; j < ncols_dst; ++j) { sumf[j] = warp_reduce_sum(sumf[j]); @@ -382,11 +384,13 @@ static void mul_mat_vec_f_switch_fusion( const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, nbytes_shared, stream}; + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -395,8 +399,8 @@ static void mul_mat_vec_f_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 73a0991e206..13b8b855282 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -424,6 +424,7 @@ static __global__ void mul_mat_vec_q( uint32_t channel_y; uint32_t sample_dst; + ggml_cuda_pdl_sync(); channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; sample_dst = blockIdx.z; @@ -683,8 +684,9 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -693,8 +695,9 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index ef98f675aa7..09d9f3a7d62 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -18,6 +18,7 @@ static __global__ void norm_f32( float2 mean_var = make_float2(0.0f, 0.0f); + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; mean_var.x += xi; @@ -46,6 +47,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int j = start; j < end; j += block_size) { tmp += x[j]; } @@ -95,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x, const uint3 add_nrows_packed = make_uint3(0, 0, 0), const uint3 add_nchannels_packed = make_uint3(0, 0, 0), const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { + ggml_cuda_pdl_lc(); const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -124,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -163,6 +167,7 @@ static __global__ void rms_norm_back_f32( float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xfi = xf[col]; sum_xx += xfi * xfi; @@ -253,6 +258,7 @@ static __global__ void l2_norm_f32( float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -261,6 +267,7 @@ static __global__ void l2_norm_f32( // sum up partial sums extern __shared__ float s_sum[]; tmp = block_reduce(tmp, s_sum); + ggml_cuda_pdl_lc(); // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html const float scale = rsqrtf(fmaxf(tmp, eps * eps)); @@ -300,10 +307,19 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params, + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } @@ -346,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } else { const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); @@ -367,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, @@ -399,10 +423,12 @@ static void l2_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - l2_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 52f664719ae..49516965cad 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -6,6 +6,7 @@ static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { + ggml_cuda_pdl_lc(); const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { @@ -28,6 +29,7 @@ static __global__ void quantize_q8_1( const int64_t ib = i_cont / QK8_1; // block index const int64_t iqs = i_cont % QK8_1; // quant index + ggml_cuda_pdl_sync(); const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -196,6 +198,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -288,6 +291,7 @@ static __global__ void quantize_mmq_q8_1( const int64_t i3 = blockIdx.z / ne2; const int64_t i00 = i0; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -378,7 +382,8 @@ void quantize_row_q8_1_cuda( const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream); + ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index de240fd4413..5895d3bf8e5 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -10,6 +10,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r const int num_unroll = 8; float temp[num_unroll]; float sum_temp[num_unroll] = { 0.0f }; + + ggml_cuda_pdl_sync(); for (int i = col; i < ncols;) { for (int j = 0; j < num_unroll; ++j) { if (i < ncols) { diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 45a49a5dc2a..e20a5cb6bed 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -134,6 +134,7 @@ static __global__ void rope_neox(const T * x, const float * freq_factors, const int64_t * row_indices, const int set_rows_stride) { + ggml_cuda_pdl_lc(); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne00) { @@ -148,6 +149,7 @@ static __global__ void rope_neox(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. @@ -216,6 +218,7 @@ static __global__ void rope_multi(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; @@ -300,6 +303,7 @@ static __global__ void rope_vision(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); const int sect_dims = sections.v[0] + sections.v[1]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; @@ -399,13 +403,14 @@ static void rope_neox_cuda(const T * x, const dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream}; if (freq_factors == nullptr) { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } @@ -443,11 +448,13 @@ static void rope_multi_cuda(const T * x, const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 0ddeff6a175..7b2e59a4383 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -3,9 +3,11 @@ #define MAX_GRIDDIM_X 0x7FFFFFFF static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + ggml_cuda_pdl_lc(); int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + ggml_cuda_pdl_sync(); for (int64_t i = tid; i < nelements; i += stride) { dst[i] = scale * x[i] + bias; } @@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, nelements); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa5..e14f96b824c 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; @@ -157,7 +158,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + ggml_cuda_pdl_lc(); const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; @@ -203,9 +206,11 @@ static void set_rows_cuda( const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); - k_set_rows<<>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, - s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, - ne11_fd, ne12_fd); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream); + ggml_cuda_kernel_launch(k_set_rows, launch_params, + src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, + s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, + ne11_fd, ne12_fd); } } diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu index 40dfe45d65c..9f0fa1051cf 100644 --- a/ggml/src/ggml-cuda/softcap.cu +++ b/ggml/src/ggml-cuda/softcap.cu @@ -1,18 +1,21 @@ #include "softcap.cuh" static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = tanhf(scale * x[i]) * softcap; } static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE; - softcap_f32<<>>(x, dst, scale, softcap, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(softcap_f32, launch_params, x, dst, scale, softcap, k); } // fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4c4daf85dc6..48787b4b890 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "ssm-conv.cuh" #include "unary.cuh" @@ -7,6 +8,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { + ggml_cuda_pdl_lc(); GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; @@ -23,6 +25,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float float x[d_conv] = { 0.0f }; float w[d_conv] = { 0.0f }; + ggml_cuda_pdl_sync(); #pragma unroll for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; @@ -128,8 +131,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_conv_f32, launch_params, src0, src1, bias, src0_nb0, src0_nb1, + src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index c1d4e2bc8df..412980376ac 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -26,6 +26,7 @@ __global__ void __launch_bounds__(splitD, 1) const int64_t s_off, const int64_t d_inner, const int64_t L_param) { const size_t L = L_template == 0 ? L_param : L_template; + ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); @@ -135,6 +136,7 @@ __global__ void __launch_bounds__(d_state, 1) const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + ggml_cuda_pdl_sync(); // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); @@ -206,7 +208,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<128/WARP_SIZE, 128>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -215,7 +218,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<256/WARP_SIZE, 256>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -231,58 +235,59 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); switch (n_tok) { case 1: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 2: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 3: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 4: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 5: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 6: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 7: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 8: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; default: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 4025771aadb..0003658ca95 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -7,10 +7,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int const dim3 block_nums(nrows, 1, 1); if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } } @@ -34,10 +36,12 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if ((nrows / nsm) < 2) { // Increase num threads to 512 for small nrows to better hide the latency const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { // Enough active SMs to hide latency, use smaller blocks to allow better scheduling const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 3020e5c7433..da20c9aab7c 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -105,6 +105,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * wt[i] = -INFINITY; } + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; @@ -161,6 +162,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * output_weights[i] = 0.f; } + ggml_cuda_pdl_lc(); for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; int max_expert = threadIdx.x; @@ -271,51 +273,52 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (n_expert) { case 1: - topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<1, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 2: - topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<2, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 4: - topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<4, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 8: - topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<8, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 16: - topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<16, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 32: - topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<32, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 64: - topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<64, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 128: - topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<128, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 256: - topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 512: - topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 576: - topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<576, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; default: GGML_ASSERT(false && "fatal error"); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 2aeba26f414..4cb805fa601 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -116,19 +116,22 @@ static __device__ __forceinline__ float op_trunc(float x) { template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = (T)op((float)x[i]); } template static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<<>>(x, dst, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_op_kernel, launch_params, x, dst, k); } template @@ -258,6 +261,7 @@ void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { template static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + ggml_cuda_pdl_lc(); const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -268,13 +272,15 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t j0 = (i / n) * o0 + (i % n); const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + ggml_cuda_pdl_sync(); dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } template static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_gated_op_kernel, launch_params, x, g, dst, k, n, o0, o1); } template From b93a5ba605580e6dc05d4ee78f5894bfa6021ffc Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 20 May 2026 07:39:01 -0700 Subject: [PATCH 119/289] hexagon: HMX quantized matmul rework (llama/23368) * hmx-mm: update debug logging in hmx-mm * hmx-mm: update dequant logic to use HVX_vector_x2/4 * hmx-mm: remove non-pipelined version of the quantize matmul It seems that we don't reall need non-pipelined version * hmx-mm: use activation depth mode and update naming Co-authored-by: Kim-Chyan Gan * hex-mm: minor hmx matmul naming updates * hmx-mm: remove unused vars * snapdragon: scripts bump default ubatch-size to 1K * hexagon: combine HMX and power and clock settings into a single set_power call * hmx-mm: remove leftover of the scale repl helper * hexagon: fix editconf error --------- Co-authored-by: Kim-Chyan Gan --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 642 ++++++--------------- ggml/src/ggml-hexagon/htp/hmx-ops.h | 13 +- ggml/src/ggml-hexagon/htp/main.c | 38 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 10 +- 4 files changed, 196 insertions(+), 507 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index e05ccfd5fc7..3ef0bcdb26d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32 // Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using // full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. -// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. -static inline void dequantize_x4x2_q4_0_x4groups_hvx( +// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. +static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt, - HVX_Vector out[4]) { + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); @@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b - volatile HVX_Vector vscale = hvx_vmemu(scales_4); - + HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - out[0] = v_lo; // group0 already in [0:63] - out[1] = v_hi; // group2 already in [0:63] + HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ + v_hi /* group2 already in [0:63] */ }; + return r; } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. @@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed } // Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, +static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, bool upper_nibbles, int sub_blk_base, const HVX_Vector vlut_cvt, - mxfp4_scales_t scales, - HVX_Vector out[4]) { + mxfp4_scales_t scales) { HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; @@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12 v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - out[0] = v_lo; - out[1] = Q6_V_vror_VR(v_lo, 64); - out[2] = v_hi; - out[3] = Q6_V_vror_VR(v_hi, 64); + HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; + return r; } // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. @@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - HVX_Vector v0[2]; const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - HVX_Vector v0[4], v1[4]; - dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + HVX_Vector_x4 dv0, dv1; + dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); if (row1 < n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { - v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); } for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; @@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); } -// - -#define FALLBACK_TO_STANDARD 1 - // C += AB static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, @@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); } - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } -} - -static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, - float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - const size_t vtcm_budget = ctx->vtcm_size; - - const size_t K_BLOCK_SIZE = 1024; - - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; - } - - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } - - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); - return -1; - } - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } - } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); - - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); - - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); - - TIMER_START(fetch); - // fetch activation block into VTCM - { - const float *activation_block = x + mr * k + kk; - - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); - } - - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; - uint8_t *dst = vtcm_scratch0; - const uint8_t *src = w + nc * row_stride; - const size_t n_rows = n_blk_sz; - const size_t src_stride = row_stride; - const size_t dst_stride = sub_row_stride; - const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - const size_t scale_off = full_qrow + blk_start * scale_blk_size; - const size_t scale_width = nb_sub * scale_blk_size; - - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); - } - TIMER_STOP(fetch); - - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); - - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); - - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); - } - TIMER_STOP(core); + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } - // store output block - { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); - } + Q6_mxmem_AR_after_hf(accum_tile, 0); } } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); -#endif - return 0; } -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { return -1; } - // for large m, k (e.g. prefill FFN Down), use out-stationary version - if (m >= 128 && k > n && n > 1024) { - int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); - if (rc != FALLBACK_TO_STANDARD) { - return rc; // 0 success, -1 error - } - FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); - // fall through to standard path - } - size_t row_stride = get_x4x2_row_stride(weight_type, k); if (row_stride == 0) { return -1; } - FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - // Only pays off when the chunker yields >=2 n-chunks, so the main loop can - // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for - // double-buffered output and the worker-dispatch overhead are pure loss. - // Try pipeline costs first; fall back to sequential if the layout collapses - // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. - const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - const size_t seq_per_mn = sizeof(__fp16); // O x 1 - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - bool use_pipeline = false; - - if (m >= 128) { - size_t mc = 0, nc = 0, used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && - hmx_ceil_div((size_t) n, nc) >= 2) { - m_chunk_n_rows = mc; - n_chunk_n_cols = nc; - vtcm_used = used; - use_pipeline = true; - } - } + const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - if (!use_pipeline) { - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + return -1; } - // Compute precise buffer sizes per execution path - const size_t weight_area_size = hex_align_up( - n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up( - m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); size_t scratch0_size, scratch1_size, scratch2_size; - if (use_pipeline) { - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 - } else { - scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 - scratch1_size = scratch0_size; // x4x2 DMA buf 1 - scratch2_size = 0; // unused - } + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, weight_type, use_pipeline, - m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); TIMER_DEFINE(weight_load); @@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_DEFINE(total); TIMER_START(total); - FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", - use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - if (!use_pipeline) { - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } - TIMER_STOP(activation_load); + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - - const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); - } + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - // Dequant + vscatter writes directly to [K, N] transposed tiles. - // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } else { - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers - - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); - } + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); } + } - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); - } + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); } } - - hmx_queue_suspend(ctx->hmx_queue); } + hmx_queue_suspend(ctx->hmx_queue); + TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; } -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; } -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { const int r2 = hmx_matmul_batch_r2(params); const int r3 = hmx_matmul_batch_r3(params); @@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_ (size_t) (dst_b3 / r3) * params->src0_nb3); } -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (const float *) ((const uint8_t *) params->activation + (size_t) dst_b2 * params->src1_nb2 + (size_t) dst_b3 * params->src1_nb3); } -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (float *) ((uint8_t *) params->dst + (size_t) dst_b2 * params->dst_nb2 + (size_t) dst_b3 * params->dst_nb3); } -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { +static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_f16_f32_batched_params_t *params) { int ret = 0; for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); } } return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } if (!params->m || !params->k || !params->n) { return -1; } if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } @@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (group_size <= 1) { FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } // Grouped path: reuse interleaved weight across all q_heads sharing a @@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads @@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 @@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu // -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index 1c78ffadd1c..f114edb822f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -33,14 +33,14 @@ typedef struct { size_t src1_nb3; size_t dst_nb2; size_t dst_nb3; -} hmx_matmul_w16a32_batched_params_t; +} hmx_matmul_f16_f32_batched_params_t; // HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output // act_stride: activation row stride in elements (= k for contiguous, or // nb[1]/sizeof(float) for permuted tensors like attention Q). // weight_stride: weight row stride in elements (= k for compact weights, or // nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const __fp16 *permuted_weight, @@ -48,13 +48,12 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, int act_stride, int weight_stride); -// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batched F16 wrapper over hmx_mat_mul_f16_f32. // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params); +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, +// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 8e54536f619..e8619388478 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -87,35 +87,37 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 { - // Power on HMX + // Power on HMX and set HMX clock HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX; - request.hmx.power_up = TRUE; - FARF(ALWAYS, "Powering HMX on\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_power = TRUE; + request.hmx_v2.power_up = TRUE; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "Error setting HMX clock."); return err; } } - -#if __HVX_ARCH__ >= 75 +#else { - // Set HMX clock + // Power on HMX HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX_v2; - request.hmx_v2.set_clock = TRUE; - request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; - FARF(ALWAYS, "Setting HMX clock\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX; + request.hmx.power_up = TRUE; + FARF(ALWAYS, "Powering HMX on\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "Error powering on HMX."); return err; } } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 2461ae617fa..46fc5602dc9 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -2995,7 +2995,6 @@ int op_matmul(struct htp_ops_context * octx) { // is handled by HMX itself; when M < 32 fall back to HVX. const int m_total = (int) src1->ne[1]; const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { return op_matmul_hvx(octx); } @@ -3020,7 +3019,7 @@ int op_matmul(struct htp_ops_context * octx) { if (src0->type == HTP_TYPE_F16) { if (is_batched) { - hmx_matmul_w16a32_batched_params_t batch_params = { + hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, .permuted_weight = (const __fp16 *) src0->data, @@ -3041,15 +3040,14 @@ int op_matmul(struct htp_ops_context * octx) { .dst_nb2 = dst->nb[2], .dst_nb3 = dst->nb[3], }; - ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + ret = hmx_matmul_f16_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, m_total, k, n, act_stride, wgt_stride); } } else { - ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, m_total, k, n, (int) src0->type); } From ad717a6de0727b64345d7141094e7b87c43952a1 Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Wed, 20 May 2026 17:15:13 +0200 Subject: [PATCH 120/289] vulkan: optimize operations in the IM2COL shader (llama/22685) * vulkan: optimize operations in the IM2COL shader * Add comments and improve the code formatting --- .../ggml-vulkan/vulkan-shaders/im2col.comp | 73 +++++++++++++++---- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index ba4c2103f0c..f4130d223b1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -44,36 +44,81 @@ void im2col(const uint ow, const uint z_idx) { const uint KHKW = p.KH * p.KW; + // Precompute base input coordinates + const int base_iw = int(ow * p.s0) - p.p0; + const int base_ih = int(oh * p.s1) - p.p1; + + // Precompute step deltas + const uint delta_ic = BLOCK_SIZE / KHKW; + const uint delta_rem = BLOCK_SIZE % KHKW; + + const uint delta_ky = delta_rem / p.KW; + const uint delta_kx = delta_rem % p.KW; + + const uint delta_ic_offset = delta_ic * p.offset_delta; + + // If using BDA mode, precompute the base pointer and step size +#if BDA + const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row; + const uint bda_step = D_SIZE * BLOCK_SIZE; +#endif + uint wg_x = gl_WorkGroupID.x; do { const uint wg_offset = wg_x * 512; - [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { - const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; + uint chw_idx = wg_offset + gidx; + + uint ic = chw_idx / KHKW; + uint rem = chw_idx % KHKW; + + uint ky = rem / p.KW; + uint kx = rem % p.KW; + uint ic_offset = src_batch + ic * p.offset_delta; + + // Initialize running pointer/index for the destination buffer +#if BDA + BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx; +#else + uint current_dst_idx = dst_row + chw_idx; +#endif + + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { if (chw_idx >= p.CHW) { return; } - const uint ic = chw_idx / KHKW; - const uint rem = chw_idx - ic * KHKW; - const uint ky = rem / p.KW; - const uint kx = rem - ky * p.KW; - - const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + const int iiw = base_iw + int(kx * p.d0); + const int iih = base_ih + int(ky * p.d1); A_TYPE val = A_TYPE(0); - if (iih < p.IH && iiw < p.IW) { - val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + if (uint(iih) < p.IH && uint(iiw) < p.IW) { + val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)]; } #if BDA - D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); - out_ptr.d = D_TYPE(val); + D_ptr(current_dst_addr).d = D_TYPE(val); + current_dst_addr += bda_step; #else - data_d[dst_row + chw_idx] = D_TYPE(val); + data_d[current_dst_idx] = D_TYPE(val); + current_dst_idx += BLOCK_SIZE; #endif + + chw_idx += BLOCK_SIZE; + ic_offset += delta_ic_offset; + kx += delta_kx; + ky += delta_ky; + + // Handle X axis wrap + uint kx_wrap = uint(kx >= p.KW); + kx -= kx_wrap * p.KW; + ky += kx_wrap; + + // Handle Y axis wrap + uint ky_wrap = uint(ky >= p.KH); + ky -= ky_wrap * p.KH; + ic_offset += ky_wrap * p.offset_delta; } wg_x += gl_NumWorkGroups.x; From 896718eacf0fd975368784a917d2e4d1856a6e70 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 20 May 2026 09:57:36 -0700 Subject: [PATCH 121/289] opencl: refactor backend initilization (llama/23318) * opencl: refactor initialization * opencl: refactor GPU identification * opencl: rename for consistency * opencl: cache global mem size in dev_ctx * opencl: adjust log level * opencl: load argsort and flash_attn kernels in supports_op * argsort kernel must be built for supports_op for querying the max workgroups * flash_attn kernel has many variants, only load them when needed --- ggml/src/ggml-opencl/ggml-opencl.cpp | 429 ++++++++++++++++----------- 1 file changed, 254 insertions(+), 175 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a3af8c2da41..5fc46f789ec 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -375,6 +375,11 @@ struct ggml_backend_opencl_device_context { ggml_backend_buffer_type buffer_type; cl_context context = nullptr; + + GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; + ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + + size_t global_mem_size = 0; }; // backend context @@ -384,6 +389,18 @@ struct ggml_backend_opencl_context { cl_device_id device; std::string device_name; + ggml_cl_version platform_version; + ggml_cl_version opencl_c_version; + + // argsort is loaded in supports_op because its availability depends on how + // many workgroups are allowed, which requires kernel compilation. + bool kernels_loaded_argsort = false; + // flash attn is loaded in supports_op because it contains multiple variants + // and takes time to compile, so we want to only compile it when needed. + bool kernels_loaded_flash_attn = false; + // rest of the kernels are currently always loaded in alloc_buffer. + bool kernels_loaded = false; + std::string driver_version; GPU_FAMILY gpu_family; @@ -781,6 +798,8 @@ struct ggml_backend_opencl_context { #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { + clFinish(queue); + ref_count--; if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING @@ -793,6 +812,9 @@ struct ggml_backend_opencl_context { // All registered devices with a default device in the front. static std::vector g_ggml_backend_opencl_devices; +// All device contexts associated with the devices above. +// The devices live as long as the process, so do the contexts. +static std::vector> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { std::ifstream ifs(path); @@ -836,12 +858,120 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } -static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { +static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // argsort + if (!backend_ctx->kernels_loaded_argsort) { + cl_int err; +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + backend_ctx->kernels_loaded_argsort = true; + } +} + +static void load_cl_kernels_flash_attn(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // flash_attn + if (!backend_ctx->kernels_loaded_flash_attn) { + cl_int err; + + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_f16 { + #include "flash_attn_f16.cl.h" + }; + const std::string kernel_src_f32 { + #include "flash_attn_f32.cl.h" + }; + const std::string kernel_src_f32_f16 { + #include "flash_attn_f32_f16.cl.h" + }; + #else + const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); + const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); + const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); + #endif + + if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { + const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { + { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, + {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, + {192, 192, 16, 16}, {256, 256, 16, 16}, + }; + + for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { + const int dk = fa_dims[i].dk; + const int dv = fa_dims[i].dv; + const int bm = fa_dims[i].bm; + const int bn = fa_dims[i].bn; + std::string OPTS = compile_opts + + " -D DK=" + std::to_string(dk) + + " -D DV=" + std::to_string(dv) + + " -D BLOCK_M=" + std::to_string(bm) + + " -D BLOCK_N=" + std::to_string(bn); + + cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); + cl_kernel k_f16, k_f16_q1; + CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); + CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; + backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; + CL_CHECK(clReleaseProgram(prog_f16)); + + cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); + cl_kernel k_f32, k_f32_q1; + CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); + CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; + backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; + CL_CHECK(clReleaseProgram(prog_f32)); + + cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); + cl_kernel k_f32_f16, k_f32_f16_q1; + CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); + CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; + backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; + CL_CHECK(clReleaseProgram(prog_f32_f16)); + + backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; + backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; + } + backend_ctx->kernels_loaded_flash_attn = true; + } + } +} + +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { + if (backend_ctx->kernels_loaded) { + return; + } + cl_int err; // compiler options for general kernels auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); std::string compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; @@ -1986,89 +2116,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // flash_attn - { - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_f16 { - #include "flash_attn_f16.cl.h" - }; - const std::string kernel_src_f32 { - #include "flash_attn_f32.cl.h" - }; - const std::string kernel_src_f32_f16 { - #include "flash_attn_f32_f16.cl.h" - }; - #else - const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); - const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); - const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); - #endif - - if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { - const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { - { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, - {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, - {192, 192, 16, 16}, {256, 256, 16, 16}, - }; - - for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { - const int dk = fa_dims[i].dk; - const int dv = fa_dims[i].dv; - const int bm = fa_dims[i].bm; - const int bn = fa_dims[i].bn; - std::string OPTS = compile_opts + - " -D DK=" + std::to_string(dk) + - " -D DV=" + std::to_string(dv) + - " -D BLOCK_M=" + std::to_string(bm) + - " -D BLOCK_N=" + std::to_string(bn); - - cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); - cl_kernel k_f16, k_f16_q1; - CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); - CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; - backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; - CL_CHECK(clReleaseProgram(prog_f16)); - - cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); - cl_kernel k_f32, k_f32_q1; - CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); - CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; - backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; - CL_CHECK(clReleaseProgram(prog_f32)); - - cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); - cl_kernel k_f32_f16, k_f32_f16_q1; - CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); - CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; - backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; - CL_CHECK(clReleaseProgram(prog_f32_f16)); - - backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; - backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; - } - GGML_LOG_CONT("."); - } - } - - // argsort - { -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "argsort.cl.h" - }; -#else - const std::string kernel_src = read_file("argsort.cl"); -#endif - backend_ctx->program_argsort_f32_i32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); - GGML_LOG_CONT("."); - } - // div { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3335,13 +3382,15 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); + backend_ctx->kernels_loaded = true; } // XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // XXX static bool initialized = false; // XXX static ggml_backend_opencl_context *backend_ctx = nullptr; -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev); +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev); namespace /* anonymous */ { extern struct ggml_backend_device_i ggml_backend_opencl_device_i; @@ -3554,13 +3603,13 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r /* .context = */ dev_ctx.get(), }); - if (!ggml_cl2_init(&found_devices.back())) { + if (!ggml_opencl_is_device_supported(&found_devices.back())) { found_devices.pop_back(); - GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); + GGML_LOG_WARN("ggml_opencl: drop unsupported device '%s'.\n", dev->name); continue; } - dev_ctx.release(); + g_ggml_backend_opencl_dev_ctxs.push_back(std::move(dev_ctx)); } if (found_devices.size()) { @@ -3577,8 +3626,79 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r return found_devices; } +// check if device should be accepted +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { + dev_ctx->gpu_family = GPU_FAMILY::ADRENO; + + // Usually device version contains the detailed device name + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); + if (dev_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); + } + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { + dev_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_WARN("ggml_opencl: unsupported GPU '%s'.\n", dev_ctx->device_name.c_str()); + dev_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return false; + } + + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + + // Check device OpenCL version, OpenCL 2.0 or above is required + ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, dev_ctx->device); + if (opencl_c_version.major < 2) { + GGML_LOG_WARN("ggml_opencl: OpenCL 2.0 or above is required\n"); + return false; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (dev_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_WARN("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return false; + } +#endif + + size_t ext_str_size; + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; + + // Check if ext_buffer contains cl_khr_fp16 + bool fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + if (!fp16_support) { + GGML_LOG_WARN("ggml_opencl: device does not support FP16\n"); + return false; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_WARN("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return false; + } + + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + return true; +} + // Initialize device if it is supported (returns nullptr if it is not). -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { GGML_ASSERT(dev); GGML_ASSERT(dev->context); @@ -3600,33 +3720,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // when the associated device is initialized backend_ctx->ref_count = 0; - if (strstr(dev_ctx->device_name.c_str(), "Adreno") || - strstr(dev_ctx->device_name.c_str(), "Qualcomm") || - strstr(dev_ctx->device_version.c_str(), "Adreno")) { - backend_ctx->gpu_family = GPU_FAMILY::ADRENO; - // Usually device version contains the detailed device name - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); - } - + backend_ctx->gpu_family = dev_ctx->gpu_family; + backend_ctx->adreno_gen = dev_ctx->adreno_gen; + if (backend_ctx->gpu_family == GPU_FAMILY::ADRENO) { // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; - } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { - backend_ctx->gpu_family = GPU_FAMILY::INTEL; - } else { - GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - return nullptr; - } - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { - GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " - "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); - return nullptr; } -#endif // Populate backend device name backend_ctx->device_name = dev_ctx->device_name; @@ -3635,13 +3734,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { cl_device_id device = backend_ctx->device; ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); - - // Check device OpenCL version, OpenCL 2.0 or above is required ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); - if (opencl_c_version.major < 2) { - GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); - return nullptr; - } + + backend_ctx->platform_version = platform_version; + backend_ctx->opencl_c_version = opencl_c_version; // Check driver version size_t driver_version_str_size; @@ -3664,34 +3760,21 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { char *ext_buffer = (char *)alloca(ext_str_size + 1); clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; - // fp16 is required - if (!backend_ctx->fp16_support) { - GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); - return nullptr; - } - - // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes - // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) - if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && - strstr(ext_buffer, "cl_intel_subgroups") == NULL) { - GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " - "(note that subgroups is an optional feature in OpenCL 3.0)\n"); - return nullptr; - } - cl_uint base_align_in_bits; CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); GGML_ASSERT(base_align_in_bits % 8u == 0); backend_ctx->alignment = base_align_in_bits / 8u; GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); - clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &backend_ctx->global_mem_size, NULL); + backend_ctx->global_mem_size = dev_ctx->global_mem_size; GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); @@ -3779,8 +3862,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); - // Load kernels - load_cl_kernels(backend_ctx.get(), opencl_c_version); + // delay kernel loading until the first buffer is created + // load_cl_kernels(backend_ctx.get()); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Allocate intermediate buffers and images @@ -3822,22 +3905,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { return dev_ctx->backend_ctx; } -static void ggml_cl2_free(ggml_backend_t backend) { +static void ggml_cl_free(ggml_backend_t backend) { ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context; ctx->free(); - - // The CL context is shared by all backends, release it if all backends have been released - bool should_release_opencl = true; - for (auto device : g_ggml_backend_opencl_devices) { - ggml_backend_opencl_device_context * ctx_dev = (ggml_backend_opencl_device_context *) device.context; - if (ctx_dev->backend_ctx->ref_count > 0) { - should_release_opencl = false; - } - } - - if (should_release_opencl) { - CL_CHECK(clReleaseContext(ctx->context)); - } } #ifdef GGML_OPENCL_USE_ADRENO_KERNELS @@ -4421,7 +4491,7 @@ static const char * ggml_backend_opencl_name(ggml_backend_t backend) { } static void ggml_backend_opencl_free(ggml_backend_t backend) { - ggml_cl2_free(backend); + ggml_cl_free(backend); } static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -4460,14 +4530,17 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { // enqueued to it won't start until commands in the other devices have // completed. static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { - if (g_ggml_backend_opencl_devices.size() < 2) - return; // No other devices to synchronize with. + if (g_ggml_backend_opencl_devices.size() < 2) { + return; // No other devices to synchronize with. + } std::vector events; events.reserve(g_ggml_backend_opencl_devices.size()); for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { - auto * other_backend_ctx = ggml_cl2_init(&backend_dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) backend_dev.context; + auto * other_backend_ctx = dev_ctx->backend_ctx; + if (backend_ctx != other_backend_ctx) { cl_event ev; CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); @@ -4880,6 +4953,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_IM2COL: return true; case GGML_OP_ARGSORT: { + load_cl_kernels_argsort(backend_ctx); + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); @@ -4897,6 +4972,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FLASH_ATTN_EXT: { + load_cl_kernels_flash_attn(backend_ctx); + const ggml_tensor * q = op->src[0]; const ggml_tensor * k = op->src[1]; const ggml_tensor * v = op->src[2]; @@ -4964,7 +5041,7 @@ static ggml_backend_i ggml_backend_opencl_i = { ggml_backend_t ggml_backend_opencl_init(void) { ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(dev); ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_opencl_guid(), @@ -5343,15 +5420,13 @@ static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) } static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device); - return (void *) (uintptr_t) backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + return (void *) (uintptr_t) dev_ctx->backend_ctx->alignment; } static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_cl2_init(buffer->buft->device); - if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); @@ -5391,7 +5466,8 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff } static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -6626,7 +6702,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor->extra); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context *backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -7470,8 +7547,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, } static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_dev_t dev = buffer->buft->device; - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + cl_command_queue queue = backend_ctx->queue; ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; @@ -7511,7 +7589,8 @@ static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer } static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(buffer_type->device); + load_cl_kernels(backend_ctx); // clCreateBuffer returns -61 for size 0 size = std::max(size, (size_t)1); @@ -7534,15 +7613,15 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b } static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - return backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + return dev_ctx->backend_ctx->alignment; } static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { static size_t max_size = -1; if (max_size == (size_t)-1) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - max_size = backend_ctx->max_alloc_size; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + max_size = dev_ctx->backend_ctx->max_alloc_size; } return max_size; } @@ -7579,14 +7658,13 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; - ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) dev_ctx->backend_ctx; static const size_t opencl_extra_margin = 1024ull*1024ull*1024ull; // OpenCL does not provide reliable currently-free device memory. // Use total/global memory as a best-effort upper bound. // Improved safety: Reduce by a 1GiB extra margin for common --fit - *total = backend_ctx->global_mem_size; + *total = dev_ctx->global_mem_size; *free = *total > opencl_extra_margin ? *total - opencl_extra_margin : 0; } @@ -7610,7 +7688,7 @@ static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct } static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context * backend_ctx = ggml_cl_init(dev); // Getting a new reference to the backend, increase ref_count backend_ctx->ref_count++; @@ -7647,6 +7725,7 @@ static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_bac } static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + ggml_cl_init(dev); return ggml_opencl_supports_op(dev, op); } @@ -7659,8 +7738,8 @@ static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggm // Check cl_context is the same. clEnqueue* commands may not use // buffers from another cl_context. - ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); - ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); + ggml_backend_opencl_context * backend_ctx0 = ggml_cl_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl_init(buft->device); return backend_ctx0->context == backend_ctx1->context; } From 6d1d66de407f67ece396467da59124fe8073bd69 Mon Sep 17 00:00:00 2001 From: Todor Boinovski Date: Wed, 20 May 2026 22:14:13 -0700 Subject: [PATCH 122/289] hexagon: ssm-conv fix for large prompts (llama/23307) * hexagon: remove gathers and better handling of vtcm in ssm-conv * hexagon: relax ssm-conv gating requirements * hexagon: add new prefill ssm-conv backend test * hexagon: remove trailing white space * hex-rope: uninline rope_cache_init, otherwise it breaks after rebaseing with SSM_CONV changes --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 7 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 4 +- ggml/src/ggml-hexagon/htp/ssm-conv.c | 388 +++++++++++++++---------- 3 files changed, 246 insertions(+), 153 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 080fb7f47e3..9db99cb0f3a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2735,9 +2735,10 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { return false; } - - // TODO: add support for non-contiguous tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (src0->nb[0] != sizeof(float) || src1->nb[0] != sizeof(float) || dst->nb[0] != sizeof(float)) { + return false; + } + if (src0->nb[1] != src0->ne[0] * sizeof(float) || src1->nb[1] != src1->ne[0] * sizeof(float)) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 9901453e91e..b398e19f06e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -107,7 +107,7 @@ static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dim cache[i0 + 1] = sinf(theta_final) * mscale_final; } -static void rope_cache_init(const float theta_base, +static __attribute__((noinline)) void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, float * corr_dims, @@ -129,7 +129,7 @@ static void rope_cache_init(const float theta_base, // pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). // sections[4]: number of head dims assigned to each position component. -static void mrope_cache_init(const float pos_t, +static __attribute__((noinline)) void mrope_cache_init(const float pos_t, const float pos_h, const float pos_w, const float pos_e, diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index a28fd03e978..d574da2e2bc 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -20,55 +20,56 @@ #include "htp-ops.h" #include "hvx-utils.h" -#define htp_ssm_conv_tensors_preamble \ - const struct htp_tensor * restrict src0 = octx->src[0]; \ - const struct htp_tensor * restrict src1 = octx->src[1]; \ - const struct htp_tensor * restrict dst = octx->dst; \ - struct htp_spad * restrict src0_spad = &octx->src0_spad; \ - struct htp_spad * restrict src1_spad = &octx->src1_spad; \ - struct htp_spad * restrict dst_spad = &octx->dst_spad; \ - \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_ssm_conv_context { struct htp_ops_context * octx; uint32_t nrows_per_thread; + uint32_t d_inner_tile; uint64_t t_start; }; -#define htp_ssm_conv_preamble \ +#define htp_ssm_conv_preamble \ struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ - struct htp_ops_context * octx = scctx->octx; \ - htp_ssm_conv_tensors_preamble; \ - dma_queue * dma_queue = octx->ctx->dma[ith]; + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; // Scalar FP32 SSM_CONV implementation static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { @@ -128,118 +129,211 @@ static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *da dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension -static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { - htp_ssm_conv_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - const int nc = src1->ne[0]; // d_conv - const int ncs = src0->ne[0]; // d_conv - 1 + n_t +// In-register 32x32 fp32 transpose using std 5-stage HVX vshuff butterfly. +static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { + HVX_Vector tmp[32]; - const uint32_t d_conv = src1->ne[0]; - const uint32_t d_inner = src0->ne[1]; - const uint32_t n_t = dst->ne[1]; - const uint32_t n_s = dst->ne[2]; + // Stage 0 (R = -4): pair (2i, 2i+1) for i = 0..15. m -> tmp. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[2*i + 1], m[2*i], -4); + tmp[2*i + 0] = Q6_V_lo_W(p); + tmp[2*i + 1] = Q6_V_hi_W(p); + } - const float * src0_data = (const float *) src0->data; - const float * src1_data = (const float *) src1->data; - float * dst_data = (float *) dst->data; + // Stage 1 (R = -8): per block of 4, pair (b+0, b+2) and (b+1, b+3). tmp -> m. + for (int b = 0; b < 32; b += 4) { + HVX_VectorPair p0 = Q6_W_vshuff_VVR(tmp[b + 2], tmp[b + 0], -8); + HVX_VectorPair p1 = Q6_W_vshuff_VVR(tmp[b + 3], tmp[b + 1], -8); + m[b + 0] = Q6_V_lo_W(p0); m[b + 1] = Q6_V_hi_W(p0); + m[b + 2] = Q6_V_lo_W(p1); m[b + 3] = Q6_V_hi_W(p1); + } - // Calculate row range for this thread - const int dr = scctx->nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = MIN(ir0 + dr, d_inner); - const uint32_t ir = ir1 - ir0; + // Stage 2 (R = -16): per block of 8, pair (b+i, b+i+4) for i = 0..3. m -> tmp. + for (int b = 0; b < 32; b += 8) { + for (int i = 0; i < 4; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[b + i + 4], m[b + i], -16); + tmp[b + 2*i + 0] = Q6_V_lo_W(p); + tmp[b + 2*i + 1] = Q6_V_hi_W(p); + } + } - if (ir0 >= ir1) { - return; // No work for this thread + // Stage 3 (R = -32): per block of 16, pair (b+i, b+i+8) for i = 0..7. tmp -> m. + for (int b = 0; b < 32; b += 16) { + for (int i = 0; i < 8; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(tmp[b + i + 8], tmp[b + i], -32); + m[b + 2*i + 0] = Q6_V_lo_W(p); + m[b + 2*i + 1] = Q6_V_hi_W(p); + } } - // src0 and src1 gather offsets - uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 }; - uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 }; + // Stage 4 (R = -64): pair (i, i+16) for i = 0..15. m -> tmp -> m. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[i + 16], m[i], -64); + tmp[2 * i + 0] = Q6_V_lo_W(p); + tmp[2 * i + 1] = Q6_V_hi_W(p); + } - for (uint32_t i = 0; i < VLEN_FP32; ++i) { - src0_offsets[i] = i * (ncs) * sizeof(float); - src1_offsets[i] = i * (d_conv) * sizeof(float); + for (int i = 0; i < 32; ++i) { + m[i] = tmp[i]; } +} - const uint32_t src0_gather_len = VLEN * ncs; - const uint32_t src1_gather_len = VLEN * d_conv; +// HVX FP32 SSM_CONV implementation - channel-vectorized HVX kernel with src0/src1 +// transposed into VTCM. +// +// VTCM layouts (per thread): +// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). +// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// +// d_inner_tile is chosen so that per-thread VTCM stays under the budget. +// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. +#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread + +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +static inline void transpose_src1(const float * src1_data, + uint32_t src1_stride_inner, + uint32_t i1_off, + uint32_t d_inner_per_thread, + uint32_t d_conv, + float * src1_T) { + for (uint32_t i = 0; i < d_inner_per_thread; ++i) { + const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; + for (uint32_t j = 0; j < d_conv; ++j) { + src1_T[j * d_inner_per_thread + i] = src_row[j]; + } + } +} - // gather scratchpads - HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0); - HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN); +// HVX 32x32 src0 transpose: src0 {ncs, d_inner} (DDR) -> src0_T {d_inner_tile, ncs} (VTCM) +static inline void transpose_src0_block(const float * src0_block, + uint32_t ncs, + uint32_t cb_n, + uint32_t d_inner_tile, + float * src0_T_block_dst, + uint32_t cb /* dst column offset */) { + const uint32_t T_TILE = VLEN_FP32; + + HVX_Vector __attribute__((aligned(VLEN))) sub[32]; + + for (uint32_t t0 = 0; t0 < ncs; t0 += T_TILE) { + const uint32_t t_n = MIN(T_TILE, ncs - t0); + + // Load 32 rows (channels) of T_TILE samples; pad missing channels with zeros. + for (uint32_t r = 0; r < cb_n; ++r) { + const float * src_row = src0_block + r * ncs + t0; + if (t_n == T_TILE) { + sub[r] = *(const HVX_UVector *) src_row; + } else { + HVX_Vector v = hvx_vec_splat_f32(0.0f); + hvx_vec_store_u(&v, t_n * sizeof(float), hvx_vec_splat_f32(0.0f)); + + float __attribute__((aligned(VLEN))) tmp[VLEN_FP32] = { 0 }; + for (uint32_t k = 0; k < t_n; ++k) tmp[k] = src_row[k]; + v = *(const HVX_Vector *) tmp; + sub[r] = v; + } + } + for (uint32_t r = cb_n; r < T_TILE; ++r) { + sub[r] = hvx_vec_splat_f32(0.0f); + } - float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]); - float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]); + hvx_transpose_32x32_f32(sub); - uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; - uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + // Store transposed sub-tile to src0_T at offsets (t0 + j) * d_inner_tile + cb. + // Only write the valid t_n rows of the transposed result. + for (uint32_t r = 0; r < t_n; ++r) { + float * dst = src0_T_block_dst + (t0 + r) * d_inner_tile + cb; + if (cb_n == T_TILE) { + *(HVX_UVector *) dst = sub[r]; + } else { + hvx_vec_store_u(dst, cb_n * sizeof(float), sub[r]); + } + } + } +} - // copy src1 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir); +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; - // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir); + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); - for (uint32_t i3 = 0; i3 < n_s; ++i3) { - float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2])); + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + const uint32_t ncs = src0->ne[0]; - // copy src0 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir); + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); - // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir); + const uint32_t dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); - dma_queue_flush(dma_queue); + if (ir0 >= ir1) { + return; + } - for (uint32_t i2 = 0; i2 < n_t; ++i2) { - float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2])); + const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_tile = scctx->d_inner_tile; - const uint32_t nvec = ir / VLEN_FP32; - const uint32_t nloe = ir % VLEN_FP32; - uint32_t i1 = 0; + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; - for (uint32_t vi1 = 0; vi1 < nvec; vi1++) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Per-thread VTCM regions. + float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); + float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. + transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); - } + const uint32_t C_TILE = VLEN_FP32; - *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec); - i1 += VLEN_FP32; - } + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t tile_off = 0; tile_off < d_inner_per_thread; tile_off += d_inner_tile) { + const uint32_t tile_n = MIN(d_inner_tile, d_inner_per_thread - tile_off); - if (nloe) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Place src0 chunk into VTCM in {d_inner_tile, ncs} layout. + const float * src0_block = src0_data + i3 * src0_stride_seq + (ir0 + tile_off) * src0_stride_inner; - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + transpose_src0_block(src0_block + cb * src0_stride_inner, ncs, cb_n, d_inner_tile, src0_T, cb); + } - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + for (uint32_t t = 0; t < n_t; ++t) { + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + + HVX_Vector acc = hvx_vec_splat_f32(0.0f); + for (uint32_t j = 0; j < d_conv; ++j) { + HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); + } + HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); + + float * dst_ptr = dst_data + i3 * dst_stride_seq + t * dst_stride_token + (ir0 + tile_off + cb); + if (cb_n == C_TILE) { + *(HVX_UVector *) dst_ptr = res; + } else { + hvx_vec_store_u(dst_ptr, cb_n * sizeof(float), res); + } } - - hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec)); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) tile=%u * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, d_inner_tile, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } @@ -264,46 +358,44 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t use_hvx = 0; - if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { - int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && - hex_is_aligned((void *) src1->data, VLEN) && - hex_is_aligned((void *) dst->data, VLEN); - - if (is_aligned) { - use_hvx = 1; - } + if (d_inner >= VLEN_FP32 && n_t >= VLEN_FP32) { + use_hvx = 1; } - if (use_hvx) { - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); - octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256); - octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256); - octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256); + const uint32_t d_inner_per_thread = scctx.nrows_per_thread; + const uint32_t ncs = src0->ne[0]; - octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; + const uint32_t src1_T_size = hex_round_up(d_conv * d_inner_per_thread * sizeof(float), 256); + const uint32_t src0_T_max = HTP_SSM_CONV_VTCM_BUDGET > src1_T_size ? HTP_SSM_CONV_VTCM_BUDGET - src1_T_size : 0; - // Compute gather scratchpad size for src0 and src1 - const size_t gather_spad_size = n_threads * VLEN * 2; + uint32_t d_inner_tile = (src0_T_max / sizeof(float)) / ncs; + d_inner_tile -= (d_inner_tile % VLEN_FP32); + if (d_inner_tile == 0) { + FARF(HIGH, "ssm_conv-f32: inner tile rounds to 0 (ncs=%u), falling back to scalar\n", ncs); + use_hvx = 0; + } else { + scctx.d_inner_tile = d_inner_tile; - octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; + octx->src0_spad.size_per_thread = hex_round_up(d_inner_tile * ncs * sizeof(float), 256); + octx->src1_spad.size_per_thread = src1_T_size; + octx->dst_spad.size_per_thread = 0; - FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", - gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, - octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, - octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data); + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = 0; - const size_t total_spad_size = - gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; - if (total_spad_size > octx->ctx->vtcm_size) { - FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size, - octx->ctx->vtcm_size); + const size_t total_spad = octx->src0_spad.size + octx->src1_spad.size; + if (total_spad > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: scratchpad %zu exceeds VTCM %zu, falling back to scalar\n", + total_spad, octx->ctx->vtcm_size); use_hvx = 0; } } From 03da9f17f47e416d88deef27096291d656d7892e Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Thu, 21 May 2026 06:24:40 +0000 Subject: [PATCH 123/289] ggml : Check the right iface method before using the fallback 2d get (llama/23306) Probably no backends implement only one of 2d get/set, but this might be annoying for some future backend developer trying to add 2d get/set. --- ggml/src/ggml-backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 4e36909f45e..5c0e5b1b9e2 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -379,7 +379,7 @@ void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf != NULL && "tensor buffer not set"); - if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + if (n_copies <= 1 || buf->iface.get_tensor_2d == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } From 158d93c8365745395da5d3f914253947d4a39e22 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 21 May 2026 13:34:08 +0300 Subject: [PATCH 124/289] metal : optimize concat kernel and fix set kernel threads (llama/23411) * metal : fix GGML_OP_SET kernel threads * tests : extend test_cpy to support different src/dst shapes Extend test_cpy to support different source and destination tensor shapes for CPY operations (reshaping), where the total number of elements must match. - Renamed ne -> ne_src, added ne_dst parameter (default: use src shape) - Added 50 new reshaping test cases covering 1D<->2D<->3D<->4D conversions - Tests exercise 1024 boundary, small shapes, and large dimensionality changes - Fixed dangling reference bug (storing & to temporary std::array) - Updated all existing test calls with permute/transpose args for compatibility Assisted-by: llama.cpp:local pi * metal : optimize concat kernel with row batching for small widths When ne0 < 256, batch multiple rows into a single threadgroup to improve occupancy. This avoids underutilizing the GPU when processing narrow tensors. - Dispatch nth = min(256, ne0) threads per group - Calculate nrptg (rows per threadgroup) to fill up to 256 threads - Update kernel index calculation to handle the row batching - Add boundary check for i1 >= ne1 Assisted-by: llama.cpp:local pi * tests : clean-up * tests : refactor CPY shape tests to use dimension permutations Replace 75 hardcoded test cases with a loop over permutations of {3, 5, 7, 32} (total elements: 3360). Each src permutation is tested against canonical sorted and reverse dst, skipping identical shapes. Covers F32, F16, and Q4_0 (when both src and dst ne0 == 32). Assisted-by: llama.cpp:local pi --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 19 +++++++++++++++---- ggml/src/ggml-metal/ggml-metal.metal | 6 +++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 8506000b6c0..206af227a2c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -564,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(1024, ne0); + int nth = std::min(256, ne0); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + if (nth < 256) { + nrptg = std::min((256 + nth - 1) / nth, ne1); + if (nrptg * nth > 256) { + nrptg = 256 / nth; + } + } + + const int nw0 = (ne1 + nrptg - 1) / nrptg; + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1); return 1; } @@ -1786,7 +1797,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nk0 = ne10/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne11, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1797,7 +1808,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4cf9dbea946..e772664ba91 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7486,7 +7486,11 @@ kernel void kernel_concat( const int i3 = tgpig.z; const int i2 = tgpig.y; - const int i1 = tgpig.x; + const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y; + + if (i1 >= args.ne1) { + return; + } int o[4] = {0, 0, 0, 0}; o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); From c436f1419f8128e6d0f9274bafe804d4b95fad96 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Thu, 21 May 2026 10:58:49 -0400 Subject: [PATCH 125/289] fix(flash-attn): replace f32 with kv_type and q_type (llama/23372) --- .../wgsl-shaders/flash_attn_tile.wgsl | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index ae8036b9ac5..4133f0ab564 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -122,9 +122,9 @@ const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; -var q_shmem: array; -var kv_shmem: array; -var p_shmem: array; +var q_shmem: array; +var kv_shmem: array; +var p_shmem: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @@ -169,10 +169,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let head = f32(head_idx); let slope = select(1.0, - select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), - pow(params.m0, head + 1.0), - head < params.n_head_log2), - params.max_bias > 0.0); + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { let q_tile_row = elem_idx / HEAD_DIM_QK; @@ -181,7 +181,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; q_shmem[elem_idx] = select( 0.0, - f32(Q[global_q_row_offset + q_col]) * params.scale, + Q_TYPE(Q[global_q_row_offset + q_col]) * params.scale, head_q_row < params.seq_len_q); } @@ -213,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(k4.x); - kv_shmem[kv_off + 1u] = f32(k4.y); - kv_shmem[kv_off + 2u] = f32(k4.z); - kv_shmem[kv_off + 3u] = f32(k4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); } workgroupBarrier(); @@ -233,18 +233,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var dot_val = 0.0; for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { let q_off = q_base + chunk * 4u; - let qv = vec4( + let qv = vec4( q_shmem[q_off + 0u], q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4( + let kv = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - dot_val += dot(qv, kv); + dot_val += dot(vec4(qv), vec4(kv)); } #ifdef LOGIT_SOFTCAP dot_val = params.logit_softcap * tanh(dot_val); @@ -271,7 +271,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = p; + p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); local_sum += p; } } @@ -285,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(v4.x); - kv_shmem[kv_off + 1u] = f32(v4.y); - kv_shmem[kv_off + 2u] = f32(v4.z); - kv_shmem[kv_off + 3u] = f32(v4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); } workgroupBarrier(); @@ -308,12 +308,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { let p = p_shmem[subgroup_p_offset + kv_local]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4( + let v4 = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += p * v4; + acc += f32(p) * vec4(v4); } out_regs[reg_idx] = acc; } From 8402c36039c1ad14fb137bd6eecc99f2961a5a35 Mon Sep 17 00:00:00 2001 From: Pascal Date: Thu, 21 May 2026 19:39:42 +0200 Subject: [PATCH 126/289] vulkan: fuse snake activation (mul, sin, sqr, mul, add) (llama/22855) * vulkan: fuse snake activation (mul, sin, sqr, mul, add) Add snake.comp shader with F32 / F16 / BF16 pipelines and ggml_vk_snake_dispatch_fused. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. test_snake_fuse from the CUDA PR now also compares CPU naive vs Vulkan fused across F32 / F16 / BF16. * vulkan: address jeffbolznv review for fused snake activation Rename T / C to ne0 / ne1 in the shader and push constants to match the standard naming convention used across the Vulkan backend. Tighten ggml_vk_can_fuse_snake: require x and dst to be contiguous (the shader uses idx = i0 + i1 * ne0) and require a / inv_b to be tightly packed on the broadcast dim (the shader reads data_a[i1]). * vulkan: tighten snake fusion type checks for all operands (address jeffbolznv review) * vulkan: reject snake fusion when ne[2] or ne[3] > 1 (address jeffbolznv review) * vulkan: address 0cc4m review for fused snake activation snake.comp is renamed to follow the ggml DATA_A_* / A_TYPE convention. A_TYPE now applies to the activation tensor data_a instead of the broadcast multiplier, and the bindings become data_a (A_TYPE), data_b (float), data_c (float) and data_d (D_TYPE). A header at the top of the shader maps each buffer to its role in y = x + sin(b * x)^2 * c. On the C++ side, ggml_vk_can_fuse_snake reuses the existing snake_pattern constant instead of duplicating the op list, sin_node is extracted as a named local alongside the other chain nodes, and the broadcast operands a and inv_b are now required to be GGML_TYPE_F32 to match the hardcoded float bindings on data_b and data_c (the previous a->type == x->type would silently reject any future BF16 or F16 chain once the supports_op gate for SIN / SQR is lifted). ggml_vk_snake_dispatch_fused gets an explicit GGML_TYPE_F32 case and GGML_ABORT on default in place of the silent f32 fallback, and a stale comment about data_a[i1] / data_inv_b[i1] is refreshed to match the new binding names. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 136 +++++++++++++++++- .../src/ggml-vulkan/vulkan-shaders/snake.comp | 49 +++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 + 3 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/snake.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d3fb19048d9..aa289220a90 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -499,6 +499,12 @@ static constexpr std::initializer_list topk_moe_late_softmax { GGM GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder +// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion. +static constexpr std::initializer_list snake_pattern { GGML_OP_MUL, GGML_OP_SIN, + GGML_OP_SQR, GGML_OP_MUL, + GGML_OP_ADD }; + //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] @@ -846,6 +852,9 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_snake_f32; + vk_pipeline pipeline_snake_f16; + vk_pipeline pipeline_snake_bf16; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; @@ -1475,6 +1484,11 @@ struct vk_op_conv_transpose_1d_push_constants { int32_t s0; }; +struct vk_op_snake_push_constants { + uint32_t ne0; + uint32_t ne1; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -4845,6 +4859,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -12110,6 +12128,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p)); } +// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b. +// Match the naive mul -> sin -> sqr -> mul -> add chain and run the +// dedicated kernel directly. The pattern is validated by +// ggml_vk_can_fuse_snake before this call. +static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + vk_pipeline pipeline = nullptr; + switch (x->type) { + case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break; + case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break; + case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break; + default: GGML_ABORT("unsupported type"); + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x); + vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a); + vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add); + + vk_op_snake_push_constants pc{}; + pc.ne0 = static_cast(x->ne[0]); + pc.ne1 = static_cast(x->ne[1]); + + std::array elements = { pc.ne0, pc.ne1, 1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -13318,7 +13375,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_MUL: - ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + if (ctx->num_additional_fused_ops) { + ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx); + } else { + ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + } break; case GGML_OP_DIV: @@ -14691,6 +14752,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } +// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add. +// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that +// the broadcast operands a and inv_b share a [1, C] layout. +static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + GGML_UNUSED(ctx); + if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) { + return false; + } + + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + const ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + if (x_in_add != x) { + return false; + } + if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) { + return false; + } + // Shader bindings: data_a is A_TYPE so it follows x's precision, while + // data_b and data_c are hardcoded float, so the broadcast operands must + // be F32 regardless of x's type. + if (a->type != GGML_TYPE_F32) return false; + if (inv_b->type != GGML_TYPE_F32) return false; + // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline). + if (mul0->type != x->type) return false; + if (sin_node->type != x->type) return false; + if (sqr->type != x->type) return false; + if (mul1->type != x->type) return false; + if (add->type != x->type) return false; + if (!ggml_are_same_shape(a, inv_b)) { + return false; + } + if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) { + return false; + } + // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b + // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader. + if (x->ne[2] != 1 || x->ne[3] != 1) return false; + if (add->ne[2] != 1 || add->ne[3] != 1) return false; + if (a->ne[2] != 1 || a->ne[3] != 1) return false; + if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false; + // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1], + // so every operand must be contiguous. + if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) || + !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) { + return false; + } + return true; +} + // Check whether the tensors overlap in memory. // Fusions can potentially overwrite src tensors in ways that are not prevented // by ggml-alloc. If the fusion src is being applied in a way that's elementwise @@ -14998,6 +15118,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg op_srcs_fused_elementwise[0] = false; op_srcs_fused_elementwise[1] = false; op_srcs_fused_elementwise[2] = false; + } else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 4; + fusion_string = "SNAKE"; + // elementwise=true: snake.comp is safe under exact aliasing because each + // thread reads data_x[idx] into a register before writing data_d[idx] + // with a data dependency on that register. The overlap check still + // rejects partial overlaps (different base or size). + std::fill_n(op_srcs_fused_elementwise, 5, true); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -15288,6 +15416,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_late_softmax)) { continue; } + if (keep_pattern(snake_pattern)) { + continue; + } // First, grab the next unused node. current_set.push_back(first_unused); @@ -15310,7 +15441,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (match_pattern(topk_moe_early_softmax_norm, j) || match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || - match_pattern(topk_moe_late_softmax, j)) { + match_pattern(topk_moe_late_softmax, j) || + match_pattern(snake_pattern, j)) { continue; } bool ok = true; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp new file mode 100644 index 00000000000..8585538cbb0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +// Fused snake activation: y = x + sin(b * x)^2 * c +// data_a [ne0, ne1] per element activation x (A_TYPE) +// data_b [1, ne1] per channel multiplier (float) +// data_c [1, ne1] per channel inverse scale (float, precomputed as 1 / freq) +// data_d [ne0, ne1] output y (D_TYPE) +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {float data_b[];}; +layout (binding = 2) readonly buffer C {float data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t ne0; + uint32_t ne1; +} p; + +// Load A_TYPE to float +float load_val(uint32_t idx) { +#if defined(DATA_A_BF16) + return bf16_to_fp32(uint32_t(data_a[idx])); +#else + return float(data_a[idx]); +#endif +} + +// Store float as D_TYPE +void store_val(uint32_t idx, float v) { +#if defined(DATA_D_BF16) + data_d[idx] = D_TYPE(fp32_to_bf16(v)); +#else + data_d[idx] = D_TYPE(v); +#endif +} + +void main() { + const uint32_t i0 = gl_GlobalInvocationID.x; + const uint32_t i1 = gl_GlobalInvocationID.y; + if (i0 >= p.ne0 || i1 >= p.ne1) return; + + const uint32_t idx = i0 + i1 * p.ne0; + const float xi = load_val(idx); + const float s = sin(data_b[i1] * xi); + store_val(idx, xi + s * s * data_c[i1]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e3a9d61a558..a1d735150fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -952,6 +952,10 @@ void process_shaders() { string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); From ec183556c6939d30272f928ac281d5390501bc51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 21 May 2026 23:35:29 +0200 Subject: [PATCH 127/289] CUDA: fix PDL CC check for JIT compilation (llama/23471) --- ggml/src/ggml-cuda/common.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9c73fe7e6fa..e54ecb29308 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1561,7 +1561,8 @@ static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_ke return env == nullptr || std::atoi(env) != 0; }(); - if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) { auto pdl_cfg = ggml_cuda_pdl_config(launch_params); CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward(args)... )); From 2d629533a5a4130ee0b91fc29637b8ffae710f6c Mon Sep 17 00:00:00 2001 From: Sachin Sharma Date: Fri, 22 May 2026 16:46:55 +0530 Subject: [PATCH 128/289] ggml-zendnn : add Q8_0 quantization support (llama/23414) * ggml-zendnn : add Q8_0 quantization support * ggml-zendnn : sync with latest ZenDNN * ggml-zendnn : address review comments for Q8_0 --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 56 ++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index f1e4f991fae..e4ba9cfbd0f 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17 + GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6a83bb6b1ec..6051d082003 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -2,6 +2,10 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + #include "zendnnl.hpp" #include @@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() { return zendnnl::common::data_type_t::f32; } else if constexpr (std::is_same_v) { return zendnnl::common::data_type_t::bf16; + } else if constexpr (std::is_same_v) { + return zendnnl::common::data_type_t::s8; } else { return zendnnl::common::data_type_t::none; } @@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params.num_threads = ctx->n_threads; zendnnl::lowoha::matmul::matmul_batch_params_t batch_params; + + if constexpr (std::is_same_v) { + params.dtypes.compute = zendnnl::common::data_type_t::s8; + const int64_t num_groups = k / QK8_0; + params.dynamic_quant = true; + params.quant_params.src_scale.buff = nullptr; + params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16; + params.quant_params.src_scale.dims = {n, num_groups}; + params.packing.pack_format_b = 1; + } + zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C @@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6 (const ggml_bf16_t *)B, ldb, (float *)C, ldc); return false; + case GGML_TYPE_Q8_0: + if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32) + return false; + return ggml_zendnn_matmul( + ctx, m, n, k, + (const block_q8_0 *)A, lda, + (const float *)B, ldb, + (float *)C, ldc); default: return false; // unsupported type } @@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat( const int64_t r3 = ne13/ne03; void * work_data = ctx->work_data.get(); - if (src1->type != vec_dot_type) { + + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; @@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { - const void* wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); if (!ggml_zendnn_sgemm(ctx, ne01, // m @@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat( static_cast(dst->data) + i12*nb2 + i13*nb3, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const size_t nbw1 = row_size; const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; - const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0; + const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0; + + // For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes), + // not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x. + const size_t f32_row_size = (size_t)ne10 * sizeof(float); + const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size; // size for MoE gather/scatter buffers - const size_t wdata_cur_size = max_rows * row_size; + const size_t wdata_cur_size = max_rows * gather_row_size; const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); // allocate single buffer for all needs @@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id( char * wdata_cur = work_data + src1_conv_size; char * dst_cur = wdata_cur + wdata_cur_size; - if (src1->type != vec_dot_type) { + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { GGML_ASSERT(src1->type == GGML_TYPE_F32); #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) @@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } } - const void * wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; // process each expert with gather -> gemm -> scatter pattern for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { @@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const int64_t i12 = row_mapping.i2; std::memcpy( - wdata_cur + ir1 * row_size, - (const char *) wdata + (i11 + i12*ne11) * row_size, - row_size + wdata_cur + ir1 * gather_row_size, + (const char *) wdata + (i11 + i12*ne11) * gather_row_size, + gather_row_size ); } @@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( dst_cur, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: return true; default: return false; From 6fb7f1af2c66ce3512df5c9edc90caea7f6e2a5f Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Fri, 22 May 2026 08:48:24 -0400 Subject: [PATCH 129/289] SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc) (llama/21580) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes #20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 --------- Co-authored-by: Claude Sonnet 4.6 --- ggml/src/ggml-sycl/dmmv.cpp | 47 +++++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 8 +++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5577bf73b28..4ae431a962e 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,6 +3,13 @@ #include "dequantize.hpp" #include "presets.hpp" +#if defined(__INTEL_LLVM_COMPILER) + #if __has_include() + #include + #define GGML_SYCL_DMMV_HAS_BF16 + #endif +#endif + static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -11,6 +18,16 @@ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat v.y() = x[ib + iqs + 1]; } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const sycl::ext::oneapi::bfloat16 *x = (const sycl::ext::oneapi::bfloat16 *)vx; + + // automatic bfloat16 -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} +#endif + static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; @@ -217,6 +234,28 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, } } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_mul_mat_vec_bf16_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + // The qk=1 kernel iterates with stride 2*GGML_SYCL_DMMV_X, so ncols must be a + // multiple of that — not just GGML_SYCL_DMMV_X — to avoid out-of-bounds reads. + GGML_ASSERT(ncols % (2*GGML_SYCL_DMMV_X) == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_bf16>(vx, y, dst, ncols, + nrows, item_ct1); + }); + } +} +#endif + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -1497,7 +1536,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_BF16; if (src1_convert_f16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, @@ -1565,6 +1605,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; +#ifdef GGML_SYCL_DMMV_HAS_BF16 + case GGML_TYPE_BF16: + convert_mul_mat_vec_bf16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; +#endif default: printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ea47f7153a..bba37a6f884 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3455,6 +3455,7 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_F16: + case GGML_TYPE_BF16: return true; default: return false; @@ -3818,8 +3819,13 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // The F16/BF16 qk=1 kernel iterates with stride 2*DMMV_X, requiring ne[0] to be + // a multiple of 2*DMMV_X. Quantized types use block-structured kernels that only + // need ne[0] % DMMV_X == 0. + const int64_t dmmv_x_required = (src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F16) ? + 2*GGML_SYCL_DMMV_X : GGML_SYCL_DMMV_X; return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + src0->ne[0] % dmmv_x_required == 0 && src1->ne[1] == 1; } static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From 0416feecfc84203d37c959493622787a419b63b0 Mon Sep 17 00:00:00 2001 From: karavayev <192749314+karavayev@users.noreply.github.com> Date: Fri, 22 May 2026 08:48:56 -0400 Subject: [PATCH 130/289] SYCL : gated_delta_net K>1 (llama/23174) * sycl_gated_delta_net K>1 * editor_config --- ggml/src/ggml-sycl/gated_delta_net.cpp | 91 +++++++++++++++++++------- 1 file changed, 66 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index ebc587524bf..9c2449aba0c 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -6,7 +6,7 @@ #include -template +template void gated_delta_net_sycl(const float * q, const float * k, const float * v, @@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q, int64_t sb3, const sycl::uint3 neqk1_magic, const sycl::uint3 rq3_magic, - float scale) { + float scale, + int K) { auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const uint32_t h_idx = item_ct1.get_group(2); const uint32_t sequence = item_ct1.get_group(1); @@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v; @@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[col * S_v + i]; + s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q, } attn_data += S_v * H; - } + // Write state back to global memory + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } + } + + if constexpr (!keep_rs_t) { +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net(const float * q_d, const float * k_d, const float * v_d, @@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d, int64_t neqk1, int64_t rq3, float scale, + int K, dpct::queue_ptr stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size; @@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 16; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 32; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 64; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 128; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } From 21c65a78a3472275db7ccdc40078fa4a3858724a Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Fri, 22 May 2026 21:49:45 +0900 Subject: [PATCH 131/289] sycl : Level Zero detection in ggml_sycl_init (llama/23097) * [SYCL] Centralize Level Zero detection in ggml_sycl_init * use the same wording * get back the warning --- ggml/src/ggml-sycl/common.hpp | 2 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 26 ++++++++------------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 96bc1c98bd9..6d19538215e 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -238,6 +238,8 @@ struct ggml_sycl_device_info { std::array default_tensor_split = {}; int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; + + bool ext_oneapi_level_zero = true; // sycl::backend::ext_oneapi_level_zero used by all enumerated GPU devices }; const ggml_sycl_device_info & ggml_sycl_info(); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index bba37a6f884..46795f43602 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -98,7 +98,7 @@ static ggml_sycl_device_info ggml_sycl_init() { for (int i = 0; i < info.device_count; ++i) { info.devices[i].vmm = 0; dpct::device_info prop; - sycl::device device = dpct::dev_mgr::instance().get_device(i); + auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); @@ -117,6 +117,12 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); info.devices[i].hw_info = get_device_hw_info(&device); + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + if (device.is_gpu() && device.default_queue().get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + info.ext_oneapi_level_zero = false; + } } for (int id = 0; id < info.device_count; ++id) { @@ -230,26 +236,10 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); #ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO - g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1); + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); #else g_ggml_sycl_enable_level_zero = 0; #endif - if (g_ggml_sycl_enable_level_zero) { - // Verify all GPU devices use the Level Zero backend before enabling L0 APIs. - // Only check GPU devices; CPU devices use OpenCL and would otherwise - // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. - for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); i++) { - auto & q = dpct::dev_mgr::instance().get_device(i).default_queue(); - if (!q.get_device().is_gpu()) { - continue; - } - if (q.get_backend() != sycl::backend::ext_oneapi_level_zero) { - GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); - g_ggml_sycl_enable_level_zero = 0; - break; - } - } - } #ifdef SYCL_FLASH_ATTN g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); From b0c9f9005926301e8a8306bd300e86dbd09db41d Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Fri, 22 May 2026 21:50:17 +0900 Subject: [PATCH 132/289] SYCL: improve MoE prefill throughput (llama/23142) - change `k_copy_src1_to_contiguous` so that uses a precomputed contiguous mapping where all rows "owned" by an expert are in one slice with a know starts and ends - switch the `O(n_as * n_routed_rows)` contraption to a counting sort-based procedure with `O(n_as + n_routed_rows)` complexity --- ggml/src/ggml-sycl/ggml-sycl.cpp | 195 +++++++++++++++++-------------- 1 file changed, 105 insertions(+), 90 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 46795f43602..b3fbb621196 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3919,35 +3919,17 @@ struct mmid_row_mapping { __dpct_inline__ static void k_copy_src1_to_contiguous( const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, - int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, - const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, - const sycl::nd_item<3> &item_ct1, int &src1_row) { - int32_t iid1 = item_ct1.get_group(2); - int32_t id = item_ct1.get_group(1); - - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + const sycl::nd_item<3> &item_ct1) { + const int32_t src1_row = item_ct1.get_group(2); - if (row_id_i != i02) { - return; - } + const int32_t iid1 = row_mapping[src1_row].i2; + const int32_t id = row_mapping[src1_row].i1; const int64_t i11 = id % ne11; const int64_t i12 = iid1; - if (item_ct1.get_local_id(2) == 0) { - src1_row = - dpct::atomic_fetch_add( - cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - /* - DPCT1065:194: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. - */ - item_ct1.barrier(); - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); @@ -4022,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused( src1_row_stride, stream); } +// counting sort of the routed rows by expert id (row_id_i, as chosen by the router): +// builds a projection of a memory layout where each expert's slice is contiguous +static void mmid_counting_sort_rows( + const ggml_tensor * ids, const char * ids_host, + int64_t n_ids, int64_t n_as, int64_t n_routed_rows, + std::vector & expert_counts, + std::vector & expert_row_offsets, + std::vector & routed_row_src) { + + // frequencies: how many routed rows each expert "owns" + expert_counts.assign(n_as, 0); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + expert_counts[row_id_i]++; + } + } + + // where each expert's slice starts (row indices) and the previous ends + expert_row_offsets.assign(n_as + 1, 0); + for (int64_t i02 = 0; i02 < n_as; i02++) { + expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02]; + } + + std::vector expert_row_next = expert_row_offsets; + routed_row_src.resize(n_routed_rows); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + // find and validate the next free row for a given expert (row_id_i) + const int64_t routed_row = expert_row_next[row_id_i]++; + GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]); + GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]); + routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1}; + } + } +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -4100,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + // how many "owned" routed rows to pass to each expert + std::vector expert_row_counts; + // where each expert's slice starts and the previous ends (row indices, right-exclusive) + std::vector expert_row_offsets; + // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous + std::vector routed_row_src; - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows, + expert_row_counts, expert_row_offsets, routed_row_src); - if (row_id_i != i02) { - continue; - } + ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), n_routed_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping)))); - num_src1_rows++; - } - } + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_row_mapping_get, + ne11, ne10, nb11, nb12, + item_ct1); + }); + }); + } + + for (int64_t i02 = 0; i02 < n_as; i02++) { + const int64_t num_src1_rows = expert_row_counts[i02]; if (num_src1_rows == 0) { continue; } - - ggml_sycl_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); - - const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); - - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); - sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor src1_row_acc(cgh); - - char *__restrict src1_contiguous_get = - src1_contiguous.get(); - int *__restrict dev_cur_src1_row_get = - dev_cur_src1_row.get(); - mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - size_t ids_nb_ct6 = ids->nb[1]; - size_t ids_nb_ct7 = ids->nb[0]; - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_src1_to_contiguous( - src1_original, src1_contiguous_get, - dev_cur_src1_row_get, - dev_row_mapping_get, ids_dev, i02, - ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, - item_ct1, src1_row_acc); - }); - }); - } + const int64_t expert_row_offset = expert_row_offsets[i02]; src0_row.data = src0_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.data = src1_contiguous.get() + expert_row_offset*nb11; src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.data = dst_contiguous.get() + expert_row_offset*nb1; dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); - sycl::range<3> grid_dims(1, 1, num_src1_rows); - stream->submit([&](sycl::handler &cgh) { - const char *__restrict dst_contiguous_get = - dst_contiguous.get(); - const mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_dst_from_contiguous(dst_original, - dst_contiguous_get, - dev_row_mapping_get, - ne0, nb1, nb2, item_ct1); - }); - }); - } + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); } } } From aefffa1fa5a7aadeb138106645bb5b8e6d8370d3 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Fri, 22 May 2026 17:08:41 -0700 Subject: [PATCH 133/289] opencl: generalize Adreno MoE kernels on M (llama/23449) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 18 +++--- ggml/src/ggml-opencl/kernels/cvt.cl | 64 +++++++++++++++++++ .../kernels/gemm_moe_mxfp4_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_0_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_1_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_k_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_0_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_1_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_k_f32_ns.cl | 6 +- .../kernels/gemm_moe_q6_k_f32_ns.cl | 6 +- .../kernels/gemv_moe_mxfp4_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_0_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_1_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_k_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_0_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_1_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_k_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q6_k_f32_ns.cl | 4 ++ 18 files changed, 145 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 5fc46f789ec..ea0b44feea2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4693,7 +4693,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { GGML_UNUSED(backend_ctx); int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 32 == 0); } inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { @@ -14297,7 +14297,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14513,7 +14513,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14689,7 +14689,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14865,7 +14865,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15118,7 +15118,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15291,7 +15291,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15469,7 +15469,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15644,7 +15644,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 312366984b6..c25eabdd72b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -220,6 +220,10 @@ kernel void kernel_convert_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -263,6 +267,10 @@ kernel void kernel_restore_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -401,6 +409,10 @@ kernel void kernel_convert_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -446,6 +458,10 @@ kernel void kernel_restore_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -491,6 +507,10 @@ kernel void kernel_convert_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -536,6 +556,10 @@ kernel void kernel_restore_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -583,6 +607,10 @@ kernel void kernel_convert_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -630,6 +658,10 @@ kernel void kernel_restore_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -679,6 +711,10 @@ kernel void kernel_convert_block_q4_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -732,6 +768,10 @@ kernel void kernel_restore_block_q4_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -784,6 +824,10 @@ kernel void kernel_convert_block_q5_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -850,6 +894,10 @@ kernel void kernel_restore_block_q5_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -916,6 +964,10 @@ kernel void kernel_convert_block_q6_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; @@ -993,6 +1045,10 @@ kernel void kernel_restore_block_q6_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -1147,6 +1203,10 @@ kernel void kernel_convert_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -1190,6 +1250,10 @@ kernel void kernel_restore_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl index e404f392bdd..02cdbdd9fb1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -163,7 +163,7 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -248,6 +248,10 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl index 02290c17eb1..d403ed0cab1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl @@ -115,7 +115,7 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -198,6 +198,10 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl index e2574ae0187..b2bddf3f73a 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -200,6 +200,10 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl index 9d24aff6a20..ab8228d18ca 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -133,7 +133,7 @@ kernel void kernel_gemm_moe_q4_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -225,6 +225,10 @@ kernel void kernel_gemm_moe_q4_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl index 3524cb1bdbd..d1a35d58bb2 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -202,6 +202,10 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl index 5fc2a523234..90d345ecf51 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl @@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -204,6 +204,10 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl index 808a0c7db6a..13c26f6f3b6 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -134,7 +134,7 @@ kernel void kernel_gemm_moe_q5_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -230,6 +230,10 @@ kernel void kernel_gemm_moe_q5_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl index a040335adfa..85ccebec78c 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q6_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -209,6 +209,10 @@ kernel void kernel_gemm_moe_q6_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl index e4b44c1a56a..75129e20c65 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -82,6 +82,10 @@ __kernel void kernel_gemv_moe_mxfp4_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl index 6f4d3f53216..2d28db63ec5 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl @@ -37,6 +37,10 @@ __kernel void kernel_gemv_moe_q4_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl index 3739a215705..b98bdc0f12e 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q4_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl index 13d79f2526f..12464e9826e 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -54,6 +54,10 @@ __kernel void kernel_gemv_moe_q4_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl index 938054cf982..b43613638a8 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q5_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl index f33a4ef2757..7a666006e68 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl @@ -39,6 +39,10 @@ __kernel void kernel_gemv_moe_q5_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl index f128d44340a..7d868d7abd9 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -55,6 +55,10 @@ __kernel void kernel_gemv_moe_q5_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl index 526e609dc3a..c166bad5ba5 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q6_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; From 6b85d73b33f497e2c37daf2ff0da15a0d7fdfcaf Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 23 May 2026 02:44:46 -0500 Subject: [PATCH 134/289] vulkan: fix windows find_package of SPIRV-Headers (llama/23215) * vulkan: fix windows find_package of SPIRV-Headers * not windows-only --- ggml/src/ggml-vulkan/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 6dbcea065b3..65785ae4566 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,7 +8,10 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) -find_package(SPIRV-Headers REQUIRED) +if (DEFINED ENV{VULKAN_SDK}) + list(APPEND CMAKE_PREFIX_PATH "$ENV{VULKAN_SDK}") +endif() +find_package(SPIRV-Headers CONFIG REQUIRED) if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files From 511f8602b14fc0c5b32b5163f4f820aac8e6ee8e Mon Sep 17 00:00:00 2001 From: dskwe Date: Sat, 23 May 2026 18:49:24 +0800 Subject: [PATCH 135/289] ggml : Check the right iface method before using the fallback 2d get (llama/23514) --- ggml/src/ggml-backend.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 5c0e5b1b9e2..87615921c09 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -306,7 +306,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + if (n_copies <= 1 || backend->iface.get_tensor_2d_async == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } @@ -317,7 +317,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ } GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); } From b84d03487c0c1cda50cfc903d281ae4a06675ae4 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Sat, 23 May 2026 19:56:59 -0700 Subject: [PATCH 136/289] hexagon: apply repl optimization in flash attn softmax as #22993 (llama/23455) --- ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 4a4ff0b331d..9e1b778b01f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -852,9 +852,10 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. - // vror brings the target lane to lane 0, then extract + re-splat. - HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); - HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + // vror brings the target lane to lane 0, then vdelta replicates it + // across all lanes — stays in the vector domain (no store/reload). + HVX_Vector v_m_prev0 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)); + HVX_Vector v_m_prev1 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)); // HVX max — both operands are splats, so result is splat of m_new. HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); From 1435988ab360285c975bbc812f2b85d0f9a5b5dd Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Sat, 23 May 2026 23:11:43 -0700 Subject: [PATCH 137/289] opencl: batch profiling to improve speed and prevent memory leaks (llama/23495) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 36 +++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index ea0b44feea2..42286435bc6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -661,11 +661,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector profiling_info; + std::vector profiling_results; - void write_profiling_info() { - FILE * fperf = fopen("cl_profiling.csv", "w"); - if (!fperf) { - GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + void flush_profiling_batch() { + if (profiling_info.empty()) { return; } @@ -689,6 +688,7 @@ struct ggml_backend_opencl_context { CL_CHECK(clGetEventProfilingInfo( info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL)); CL_CHECK(clReleaseEvent(info.evt)); + info.evt = nullptr; char kernel_name[512]; CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME, @@ -706,10 +706,26 @@ struct ggml_backend_opencl_context { info.cmd_complete_duration_ns = cmd_complete - cmd_end; info.cmd_total_duration_ns = cmd_complete - cmd_queued; } + profiling_results.insert(profiling_results.end(), + std::make_move_iterator(profiling_info.begin()), + std::make_move_iterator(profiling_info.end())); + profiling_info.clear(); + } + + void write_profiling_info() { + if (profiling_results.empty()) { + return; + } // Dump a csv + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", info.op_name.c_str(), info.kernel_name.c_str(), info.cmd_duration_ns/1.e6f, @@ -720,14 +736,14 @@ struct ggml_backend_opencl_context { fclose(fperf); // Dump a simple chrome trace - FILE* ftrace = fopen("cl_trace.json", "w"); + FILE * ftrace = fopen("cl_trace.json", "w"); if (!ftrace) { GGML_LOG_ERROR("Failed to open cl_trace.json\n"); return; } fprintf(ftrace, "[\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_queued/1000); fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", @@ -738,6 +754,7 @@ struct ggml_backend_opencl_context { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_end/1000); } + fprintf(ftrace, "]\n"); fclose(ftrace); } @@ -758,6 +775,9 @@ struct ggml_backend_opencl_context { profiling_info.emplace_back(); populateProfilingInfo(profiling_info.back(), evt, kernel, work_dim, global_work_size, local_work_size, tensor); + if (profiling_info.size() >= 2048) { + flush_profiling_batch(); + } #else GGML_UNUSED(tensor); CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL)); @@ -804,7 +824,7 @@ struct ggml_backend_opencl_context { if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING write_profiling_info(); - profiling_info.clear(); + profiling_results.clear(); #endif } } From 3306af62b1e3ad2a32f461700f7cccfd531fe08f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 24 May 2026 08:19:33 +0200 Subject: [PATCH 138/289] TP: fix entirely zero-sized slices per device (llama/23525) --- ggml/include/ggml-alloc.h | 1 + ggml/src/ggml-backend-meta.cpp | 36 ++++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 78aa059dde3..a7926a21a9a 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -76,6 +76,7 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i // Utils // Create a buffer and allocate all the tensors in a ggml_context // ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft +// ggml_backend_alloc_ctx_tensors_from_buft returns NULL on failure or if all tensors in ctx are already allocated or zero-sized GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index df0f405ed9f..5f9ae9c1bc5 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1275,6 +1275,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1382,6 +1385,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co for (size_t j = 0; j < n_bufs; j++){ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1445,6 +1451,7 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac buf_ctx->buf_configs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + GGML_ASSERT(simple_buf != nullptr); max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); } @@ -1474,8 +1481,27 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc t->data = (void *) 0x2000000000000000; // FIXME } for (size_t i = 0; i < n_simple_bufts; i++) { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( - meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + ggml_context * ctx = meta_buf_ctx->buf_configs[i].ctx; + ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); + + // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. + // For those edge cases, allocate a dummy buffer instead. + bool any_nonzero_slice = false; + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + if (ggml_nelements(t) != 0) { + any_nonzero_slice = true; + break; + } + } + if (any_nonzero_slice) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft); + } else { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_buft_alloc_buffer(simple_buft, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf_ctx->buf_configs[i].buf; + } + } + GGML_ASSERT(meta_buf_ctx->buf_configs[i].buf != nullptr); meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); } return meta_buf; @@ -1605,6 +1631,9 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1646,6 +1675,9 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; From a369b3949c2f4f624bb8a0324cea870661581194 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 25 May 2026 02:15:46 -0500 Subject: [PATCH 139/289] ggml : Parallelize quant LUT init (llama/23595) - Use OpenMP to parallelize iq2xs_init_impl and iq3xs_init_impl. - Move the OpenMP detection from ggml-cpu to ggml-base. - Update OpenMP dependencies in ggml-config.cmake.in. --- ggml/cmake/ggml-config.cmake.in | 14 +- ggml/src/CMakeLists.txt | 17 ++ ggml/src/ggml-cpu/CMakeLists.txt | 14 +- ggml/src/ggml-quants.c | 328 ++++++++++++++++++++----------- 4 files changed, 246 insertions(+), 127 deletions(-) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 91c9d5cd343..23a3066f56d 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -6,6 +6,7 @@ include(CMakeFindDependencyMacro) find_dependency(Threads) if (NOT GGML_SHARED_LIB) + set(GGML_BASE_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_OPTIONS "") @@ -20,7 +21,15 @@ if (NOT GGML_SHARED_LIB) if (GGML_OPENMP_ENABLED) find_dependency(OpenMP) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + set(GGML_OPENMP_INTERFACE_LINK_LIBRARIES "") + if (TARGET OpenMP::OpenMP_C) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C) + endif() + if (TARGET OpenMP::OpenMP_CXX) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_CXX) + endif() + list(APPEND GGML_BASE_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) endif() if (GGML_CPU_HBM) @@ -122,7 +131,8 @@ if(NOT TARGET ggml::ggml) add_library(ggml::ggml-base UNKNOWN IMPORTED) set_target_properties(ggml::ggml-base PROPERTIES - IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}" + INTERFACE_LINK_LIBRARIES "${GGML_BASE_INTERFACE_LINK_LIBRARIES}") set(_ggml_all_targets "") if (NOT GGML_BACKEND_DL) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 3e48860bfc8..c26c3f1470d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -222,6 +222,23 @@ if (GGML_SCHED_NO_REALLOC) target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) endif() +if (GGML_OPENMP) + find_package(OpenMP) + if (OpenMP_FOUND) + set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") + else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") + message(WARNING "OpenMP not found") + endif() +else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") +endif() + +if (GGML_OPENMP_ENABLED) + target_compile_definitions(ggml-base PRIVATE GGML_USE_OPENMP) + target_link_libraries(ggml-base PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) +endif() + add_library(ggml ggml-backend-dl.cpp ggml-backend-reg.cpp) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index f3eccff7d72..8c735a045b3 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -72,17 +72,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() - if (GGML_OPENMP) - find_package(OpenMP) - if (OpenMP_FOUND) - set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") - target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) - - target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) - else() - set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") - message(WARNING "OpenMP not found") - endif() + if (GGML_OPENMP_ENABLED) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) + target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) endif() if (GGML_LLAMAFILE) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 15443aa554a..15d231f70c0 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -13,6 +13,10 @@ #include // for qsort #include // for GGML_ASSERT +#ifdef GGML_USE_OPENMP +#include +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -3064,70 +3068,121 @@ void iq2xs_init_impl(enum ggml_type type) { } kmap_q2xs[index] = i; } - int8_t pos[8]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // The neighbour search runs in three passes: + // 1. Parallel: for each i, qsort and count its neighbours into n_per_i, + // and reduce the totals (num_neighbors, num_not_in_map). + // 2. Serial: prefix-sum n_per_i into offsets[], so each i has a + // pre-assigned slice of kneighbors_q2xs to write into. + // 3. Parallel: redo the qsort and write each i's neighbour list at + // offsets[i]. + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq2_data[gindex].neighbours = kneighbors_q2xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - kmap_q2xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q2xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q2xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q2xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int local_counter = offsets[i]; + kmap_q2xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq2xs_free_impl(enum ggml_type type) { @@ -3663,70 +3718,115 @@ void iq3xs_init_impl(int grid_size) { } kmap_q3xs[index] = i; } - int8_t pos[4]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // See explanation of parallelism in iq2xs_init_impl + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq3_data[gindex].neighbours = kneighbors_q3xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - kmap_q3xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q3xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q3xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q3xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int local_counter = offsets[i]; + kmap_q3xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq3xs_free_impl(int grid_size) { From 946d6813b9d999008c22a3c3b11332bac6eece1f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:13:21 +0300 Subject: [PATCH 140/289] ggml : bump version to 0.12.1 (ggml/1508) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4aac5094d1c..03020888f97 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 12) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 0a62a579ccdb649a321fd7a04c2f874d8cd5257a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:14:40 +0300 Subject: [PATCH 141/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 5a605ba344e..2c680ce9f5d 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -57ea0bc119d722d74594196cc5b494a34dd87be4 +0a37c2167fc5b81830a32d9b1691610180ed86d6 From 865ec171aa83625a388bce0b43f091bb3054f56b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:18:31 +0300 Subject: [PATCH 142/289] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 27 +- examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-chat.cpp | 8 +- examples/talk-llama/llama-chat.h | 2 +- examples/talk-llama/llama-context.cpp | 196 ++++++++- examples/talk-llama/llama-context.h | 9 + examples/talk-llama/llama-cparams.h | 4 + examples/talk-llama/llama-ext.h | 16 + examples/talk-llama/llama-graph.cpp | 43 +- examples/talk-llama/llama-graph.h | 6 +- examples/talk-llama/llama-hparams.cpp | 6 + examples/talk-llama/llama-hparams.h | 2 + .../talk-llama/llama-memory-hybrid-iswa.cpp | 14 +- .../talk-llama/llama-memory-hybrid-iswa.h | 1 + examples/talk-llama/llama-memory-hybrid.cpp | 14 +- examples/talk-llama/llama-memory-hybrid.h | 1 + .../talk-llama/llama-memory-recurrent.cpp | 131 +++++- examples/talk-llama/llama-memory-recurrent.h | 9 + examples/talk-llama/llama-memory.h | 3 + examples/talk-llama/llama-model-loader.cpp | 13 +- examples/talk-llama/llama-model-loader.h | 2 +- examples/talk-llama/llama-model-saver.cpp | 2 + examples/talk-llama/llama-model.cpp | 57 ++- examples/talk-llama/llama-model.h | 21 +- examples/talk-llama/llama-vocab.cpp | 125 +++++- examples/talk-llama/llama.h | 11 +- examples/talk-llama/models/afmoe.cpp | 2 +- examples/talk-llama/models/apertus.cpp | 2 +- examples/talk-llama/models/arcee.cpp | 2 +- examples/talk-llama/models/arctic.cpp | 2 +- examples/talk-llama/models/arwkv7.cpp | 2 +- examples/talk-llama/models/baichuan.cpp | 2 +- examples/talk-llama/models/bailingmoe.cpp | 2 +- examples/talk-llama/models/bailingmoe2.cpp | 2 +- examples/talk-llama/models/bloom.cpp | 2 +- examples/talk-llama/models/chameleon.cpp | 2 +- examples/talk-llama/models/chatglm.cpp | 2 +- examples/talk-llama/models/codeshell.cpp | 2 +- examples/talk-llama/models/cogvlm.cpp | 2 +- examples/talk-llama/models/cohere2.cpp | 2 +- examples/talk-llama/models/command-r.cpp | 2 +- examples/talk-llama/models/dbrx.cpp | 2 +- examples/talk-llama/models/deci.cpp | 2 +- examples/talk-llama/models/deepseek.cpp | 2 +- examples/talk-llama/models/delta-net-base.cpp | 164 +++++++- examples/talk-llama/models/dots1.cpp | 2 +- examples/talk-llama/models/dream.cpp | 2 +- examples/talk-llama/models/ernie4-5-moe.cpp | 2 +- examples/talk-llama/models/ernie4-5.cpp | 2 +- examples/talk-llama/models/exaone-moe.cpp | 2 +- examples/talk-llama/models/exaone.cpp | 2 +- examples/talk-llama/models/exaone4.cpp | 2 +- examples/talk-llama/models/falcon-h1.cpp | 2 +- examples/talk-llama/models/falcon.cpp | 2 +- examples/talk-llama/models/gemma.cpp | 2 +- examples/talk-llama/models/gemma2.cpp | 2 +- examples/talk-llama/models/gemma3.cpp | 2 +- examples/talk-llama/models/gemma3n.cpp | 2 +- examples/talk-llama/models/gemma4.cpp | 2 +- examples/talk-llama/models/glm4-moe.cpp | 2 +- examples/talk-llama/models/glm4.cpp | 2 +- examples/talk-llama/models/gpt2.cpp | 2 +- examples/talk-llama/models/gptneox.cpp | 2 +- examples/talk-llama/models/granite-hybrid.cpp | 2 +- examples/talk-llama/models/granite.cpp | 2 +- examples/talk-llama/models/grok.cpp | 2 +- examples/talk-llama/models/grovemoe.cpp | 2 +- examples/talk-llama/models/hunyuan-moe.cpp | 2 +- examples/talk-llama/models/hunyuan-vl.cpp | 2 +- examples/talk-llama/models/internlm2.cpp | 2 +- examples/talk-llama/models/jais.cpp | 2 +- examples/talk-llama/models/jais2.cpp | 2 +- examples/talk-llama/models/jamba.cpp | 2 +- examples/talk-llama/models/lfm2.cpp | 2 +- examples/talk-llama/models/llada-moe.cpp | 2 +- examples/talk-llama/models/llada.cpp | 2 +- examples/talk-llama/models/llama.cpp | 2 +- examples/talk-llama/models/llama4.cpp | 2 +- examples/talk-llama/models/maincoder.cpp | 2 +- examples/talk-llama/models/mamba.cpp | 2 +- examples/talk-llama/models/mimo2.cpp | 2 +- examples/talk-llama/models/minicpm3.cpp | 2 +- examples/talk-llama/models/minimax-m2.cpp | 2 +- examples/talk-llama/models/mistral3.cpp | 2 +- examples/talk-llama/models/models.h | 33 +- examples/talk-llama/models/mpt.cpp | 2 +- examples/talk-llama/models/nemotron-h.cpp | 2 +- examples/talk-llama/models/nemotron.cpp | 2 +- examples/talk-llama/models/olmo.cpp | 2 +- examples/talk-llama/models/olmo2.cpp | 2 +- examples/talk-llama/models/olmoe.cpp | 2 +- examples/talk-llama/models/openai-moe.cpp | 2 +- examples/talk-llama/models/openelm.cpp | 2 +- examples/talk-llama/models/orion.cpp | 2 +- examples/talk-llama/models/paddleocr.cpp | 2 +- examples/talk-llama/models/pangu-embed.cpp | 2 +- examples/talk-llama/models/phi2.cpp | 2 +- examples/talk-llama/models/phi3.cpp | 2 +- examples/talk-llama/models/plamo.cpp | 2 +- examples/talk-llama/models/plamo2.cpp | 2 +- examples/talk-llama/models/plamo3.cpp | 2 +- examples/talk-llama/models/plm.cpp | 2 +- examples/talk-llama/models/qwen.cpp | 2 +- examples/talk-llama/models/qwen2.cpp | 2 +- examples/talk-llama/models/qwen2moe.cpp | 2 +- examples/talk-llama/models/qwen2vl.cpp | 2 +- examples/talk-llama/models/qwen3.cpp | 2 +- examples/talk-llama/models/qwen35.cpp | 321 +++++++++++---- examples/talk-llama/models/qwen35moe.cpp | 373 ++++++++++++++---- examples/talk-llama/models/qwen3moe.cpp | 2 +- examples/talk-llama/models/qwen3next.cpp | 46 +-- examples/talk-llama/models/qwen3vl.cpp | 2 +- examples/talk-llama/models/qwen3vlmoe.cpp | 2 +- examples/talk-llama/models/refact.cpp | 2 +- examples/talk-llama/models/rnd1.cpp | 2 +- examples/talk-llama/models/rwkv6.cpp | 2 +- examples/talk-llama/models/rwkv6qwen2.cpp | 2 +- examples/talk-llama/models/rwkv7.cpp | 2 +- examples/talk-llama/models/seed-oss.cpp | 2 +- examples/talk-llama/models/smallthinker.cpp | 2 +- examples/talk-llama/models/smollm3.cpp | 2 +- examples/talk-llama/models/stablelm.cpp | 2 +- examples/talk-llama/models/starcoder.cpp | 2 +- examples/talk-llama/models/starcoder2.cpp | 2 +- examples/talk-llama/models/step35.cpp | 2 +- examples/talk-llama/models/t5.cpp | 2 +- .../talk-llama/models/wavtokenizer-dec.cpp | 2 +- examples/talk-llama/models/xverse.cpp | 2 +- examples/talk-llama/unicode.cpp | 133 +++++++ 129 files changed, 1593 insertions(+), 395 deletions(-) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 59dde99e362..c9eead18aa3 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -757,14 +757,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index e37d548c98e..89cf16cc37c 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 6554a89b28a..f10397747b0 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -73,7 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, - { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, + { "hunyuan-vl", LLM_CHAT_TEMPLATE_HUNYUAN_VL }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -218,7 +218,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; + return LLM_CHAT_TEMPLATE_HUNYUAN_VL; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -825,8 +825,8 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { - // tencent/HunyuanOCR + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_VL) { + // tencent/HunyuanOCR & tencent/HunyuanVL ss << "<|hy_begin▁of▁sentence|>"; for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 13f936a946c..ea6540c0be7 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -53,7 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, - LLM_CHAT_TEMPLATE_HUNYUAN_OCR, + LLM_CHAT_TEMPLATE_HUNYUAN_VL, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 3d9714ab166..ad36c06667d 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -42,13 +51,22 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; + cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; + cparams.embeddings_pre_norm_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -65,6 +83,8 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -206,6 +226,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -278,6 +299,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -860,6 +882,42 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1098,13 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); + + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1072,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init && @@ -1241,7 +1319,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1392,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1460,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1622,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::mapget_ubatch(); @@ -1689,7 +1783,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -1727,8 +1822,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1905,25 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; + + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1823,6 +1938,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1893,10 +2009,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +2026,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +2050,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2067,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2096,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2163,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -2121,7 +2256,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3100,7 +3235,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3201,8 +3336,10 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3306,6 +3443,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3347,6 +3491,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } @@ -3436,6 +3584,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 92d1b0cf95a..d03f681d4a1 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); @@ -278,6 +282,11 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 9d359474132..20ec59fe335 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -27,6 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -40,6 +43,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 8ce29d217cb..edfa71c207c 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); + +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fe155c92dea..fc027de8b39 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -500,15 +500,21 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { } void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -534,14 +540,21 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { bool res = true; - res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } - res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -848,6 +861,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -2528,7 +2544,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 5cb1756c6a9..bf6778237e6 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { @@ -580,7 +581,8 @@ struct llm_graph_params { ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && ( (!ubatch.token && !other.ubatch.token) || - (!ubatch.embd && !other.ubatch.embd) + (!ubatch.embd && !other.ubatch.embd) || + (ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd) ); // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same @@ -644,6 +646,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +675,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 002d15d415f..2239309c8fb 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. + return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 0160a89caa2..e2d051edc6c 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 10e6b459797..72f5c2fea72 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_base()->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h index 807c8aac96c..c9d3f9f57c5 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.h +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 4ce1af592c1..33b3b395e0c 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index c07f1d969cb..ec5dc5835dd 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits::max(); + if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -388,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } } if (ubatch.n_tokens == 0) { @@ -703,6 +737,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq GGML_UNUSED(flags); std::vector> cell_ranges; // ranges, from inclusive, to exclusive + std::vector> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -712,6 +747,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -726,7 +790,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } - if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); } @@ -737,10 +801,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -762,6 +832,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { @@ -804,7 +882,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Write each range of cells of r_size_row length + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -825,7 +904,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Write each range of S tensor rows + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -852,9 +932,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; @@ -1163,5 +1242,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/examples/talk-llama/llama-memory-recurrent.h b/examples/talk-llama/llama-memory-recurrent.h index 47f01d73912..b13b7b748f5 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,14 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4a157b91fdb..4ad1612e45b 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 4e65a45a50d..c645d0785ab 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 7b3d6703c03..c476026d3e5 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -184,7 +184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index e83056557bf..528e4c9c069 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -393,6 +393,8 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model->output); add_tensor(model->output_b); add_tensor(model->output_norm_enc); + add_tensor(model->output_s); + add_tensor(model->output_in_s); add_tensor(model->cls); add_tensor(model->cls_b); add_tensor(model->cls_out); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index ff30a2ae7a6..0d21b2a53c5 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1334,6 +1334,12 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // input scales if (!layer.wq_in_s && layer.wq) { @@ -1393,11 +1399,30 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_in_s && layer.ssm_beta) { layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_in_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_in_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + // output scales + if (output && output->type == GGML_TYPE_NVFP4) { + // weight scale + if (!output_s) { + output_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "scale"), {1}, TENSOR_NOT_REQUIRED); + } + // input scale + if (!output_in_s) { + output_in_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "input_scale"), {1}, TENSOR_NOT_REQUIRED); + } } } - ml.done_getting_tensors(); + GGML_ASSERT(!(output && tok_embd && + strcmp(output->name, tok_embd->name) == 0 && + output->type == GGML_TYPE_NVFP4)); // populate tensors_by_name for (auto & [_, ctx_ptr] : ml.ctx_map) { for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { @@ -1934,6 +1959,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1942,8 +1973,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1958,6 +1990,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -1975,6 +2015,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -1993,6 +2034,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2000,6 +2042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2011,6 +2054,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2026,7 +2074,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2043,7 +2091,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2146,6 +2194,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d63c689185a..398a0aa725c 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -202,12 +202,16 @@ struct llama_layer_shortconv { }; struct llama_layer_nextn { - struct ggml_tensor * eh_proj = nullptr; - struct ggml_tensor * embed_tokens = nullptr; - struct ggml_tensor * enorm = nullptr; - struct ggml_tensor * hnorm = nullptr; - struct ggml_tensor * shared_head_head = nullptr; - struct ggml_tensor * shared_head_norm = nullptr; + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * eh_proj_s = nullptr; + struct ggml_tensor * eh_proj_in_s = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_head_s = nullptr; + struct ggml_tensor * shared_head_head_in_s = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; }; struct llama_layer { @@ -533,6 +537,11 @@ struct llama_model { struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + + // NVFP4 per-tensor scale2, input_scale for LM head + struct ggml_tensor * output_s = nullptr; + struct ggml_tensor * output_in_s = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index f43cf546ca0..a5cf148b268 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -530,6 +530,8 @@ struct llm_tokenizer_bpe : llm_tokenizer { struct llm_tokenizer_bpe_session { llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + virtual ~llm_tokenizer_bpe_session() = default; + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } @@ -567,7 +569,7 @@ struct llm_tokenizer_bpe_session { } } - void tokenize(const std::string & text, std::vector & output) { + virtual void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); @@ -1579,6 +1581,88 @@ struct llm_tokenizer_plamo2_session { const llm_tokenizer_plamo2 & tokenizer; }; +// reserved suffix (U+E000) that keeps DNA k-mers distinct from identical +// base-vocab BPE tokens (e.g. CCCCCC) in token_to_id; erased from id_to_token +// text at load +static const std::string dna_kmer_marker = "\xee\x80\x80"; + +struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { + llm_tokenizer_hybriddna_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector & output) override { + static const std::string open_tag = ""; + static const std::string close_tag = ""; + + const auto dna_begin_id = vocab.text_to_token(open_tag); + const auto dna_end_id = vocab.text_to_token(close_tag); + const auto dna_oov_id = vocab.text_to_token(""); + + // Fall back to plain BPE if the DNA pieces aren't in the vocab. + if (dna_begin_id == LLAMA_TOKEN_NULL || dna_end_id == LLAMA_TOKEN_NULL || dna_oov_id == LLAMA_TOKEN_NULL) { + llm_tokenizer_bpe_session::tokenize(text, output); + return; + } + + const size_t k = 6; + size_t pos = 0; + + while (pos < text.size()) { + const size_t start = text.find(open_tag, pos); + if (start == std::string::npos) { + if (pos < text.size()) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos), output); + } + break; + } + if (start > pos) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos, start - pos), output); + } + output.push_back(dna_begin_id); + + const size_t content_start = start + open_tag.size(); + const size_t end = text.find(close_tag, content_start); + const size_t content_end = (end == std::string::npos) ? text.size() : end; + + emit_dna_kmers(text.substr(content_start, content_end - content_start), k, dna_oov_id, output); + + if (end == std::string::npos) { + break; + } + output.push_back(dna_end_id); + pos = end + close_tag.size(); + } + } + +private: + void emit_dna_kmers(const std::string & raw, size_t k, llama_token oov_id, std::vector & output) { + std::string seq = raw; + for (char & c : seq) { + if (c >= 'a' && c <= 'z') { + c = char(c - 32); + } + } + + // k-mers carry the reserved marker suffix; a non-ACGT k-mer simply + // isn't in the vocab and falls back to + auto kmer_token = [&](const std::string & kmer) { + const auto tok = vocab.text_to_token(kmer + dna_kmer_marker); + return tok != LLAMA_TOKEN_NULL ? tok : oov_id; + }; + + size_t i = 0; + for (; i + k <= seq.size(); i += k) { + output.push_back(kmer_token(seq.substr(i, k))); + } + if (i < seq.size()) { + std::string kmer = seq.substr(i); + kmer.append(k - kmer.size(), 'A'); + output.push_back(kmer_token(kmer)); + } + } + + const llama_vocab & vocab; +}; + // // impl // @@ -1808,7 +1892,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2266,6 +2350,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } GGML_ASSERT(id_to_token.size() == token_to_id.size()); + // hybriddna: the marker suffix kept k-mer ids distinct in token_to_id; erase + // it from id_to_token so the k-mers detokenize to the bare DNA sequence. The + // k-mers are the block right after , so only scan from there. + if (tokenizer_model == "hybriddna") { + const auto idx = token_to_id.find(""); + if (idx != token_to_id.end()) { + auto it = id_to_token.begin() + idx->second + 1; + for (; it != id_to_token.end(); ++it) { + std::string & text = it->text; + if (text.size() > dna_kmer_marker.size() + && text.compare(text.size() - dna_kmer_marker.size(), dna_kmer_marker.size(), dna_kmer_marker) == 0) { + text.erase(text.size() - dna_kmer_marker.size()); + } + } + } + } + init_tokenizer(type); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -3144,11 +3245,19 @@ std::vector llama_vocab::impl::tokenize( } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get())); // it calls some other methods that are not exist in llm_tokenizer, // here just cast it to bpe tokenizer object + const llm_tokenizer_bpe * tok_bpe = static_cast(tokenizer.get()); + + std::unique_ptr session; + if (vocab.get_tokenizer_model() == "hybriddna") { + session = std::make_unique(vocab, *tok_bpe); + } else { + session = std::make_unique(vocab, *tok_bpe); + } + if (add_special) { - session.append_bos(output); + session->append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -3161,15 +3270,15 @@ std::vector llama_vocab::impl::tokenize( #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif - session.tokenize(text, output); + session->tokenize(text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - session.append(fragment.token, output); + session->append(fragment.token, output); } } if (add_special) { - session.append_eos(output); - session.check_double_bos_eos(output); + session->append_eos(output); + session->check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 308e8ba9dbd..e8374c53b70 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -333,9 +338,11 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -530,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @@ -866,7 +874,8 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 -// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load) +// Keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load). +// Getting the state for a seq_id with this flag invalidates all prior states gotten for that seq_id with this flag. #define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 typedef uint32_t llama_state_seq_flags; diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 602e3176afd..a7c77ee5d28 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -277,7 +277,7 @@ llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 136ff702957..bec7136521c 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -160,7 +160,7 @@ llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 70e86d41130..d086c4717ff 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -148,7 +148,7 @@ llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index d8653a44639..27deadffeb7 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -171,7 +171,7 @@ llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 79aa8c90899..9bd04127b25 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -193,7 +193,7 @@ llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4e55290e4e5..4d26081cd5d 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -146,7 +146,7 @@ llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 030dd4f42a4..fe1ae10864b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -171,7 +171,7 @@ llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index e7fe3d5b45a..2f0d44a6259 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -210,7 +210,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b600fb0c954..30b0f3d07d0 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -142,7 +142,7 @@ llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 8510b9e29f8..4bceaefd63b 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -181,7 +181,7 @@ llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_with_img_logits", -1); // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index e898eff7939..6766fa71c15 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -151,7 +151,7 @@ llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e9e85d96713..274dd3342a7 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -143,7 +143,7 @@ llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 79236121bd5..2e231bb3f93 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -150,7 +150,7 @@ llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index 12edbae1094..a514cf88fc6 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -146,7 +146,7 @@ llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index decb89f547b..adf7fcaa20f 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -131,7 +131,7 @@ llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index bce6b04bcf9..af71c775365 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -145,7 +145,7 @@ llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 9f1a959c32c..567e3535276 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -181,7 +181,7 @@ llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index c7946059662..f52ec9518b6 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -185,7 +185,7 @@ llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 6bc989c9509..4f4c7cac7a8 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,162 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + const auto mem_size = mctx_cur->get_size(); + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + + const size_t row_size = ggml_row_size(conv_states_all->type, row_count); + + if (cparams.n_rs_seq == 0) { + const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; + const int64_t s_slot = 0; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + cb(conv_state_last, "conv_state_last", il); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, conv_states_all, + row_count, n_seqs, conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + cb(conv_state_update, "conv_state_update", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } else { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are + // inside the same ubatch. currently with `split_equal()` this is not correct + + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + for (int64_t t = 1; t <= K; ++t) { + const int64_t s_idx = std::max(0, conv_input->ne[0] - conv_states->ne[0] - K + t); + const int64_t s_slot = K - t; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, + conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + const bool keep = cparams.n_rs_seq > 0; + + if (!keep) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = cparams.n_rs_seq + 1; + + // TODO: remove pad + simplify + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + if (n_seq_tokens > 1) { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + } else { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); + } + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 93cbcf9d931..435d27281c6 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -183,7 +183,7 @@ llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 60a3f0ec285..12ac6f1ce88 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -128,7 +128,7 @@ llama_model_dream::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 2bd01a2c512..8d9ff138676 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -124,7 +124,7 @@ llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index fa989fe92cd..9b39c605e35 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -155,7 +155,7 @@ llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 54bb3ca86b3..76d91982fc5 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -237,7 +237,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 75d5f60631c..c7e9960d718 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -127,7 +127,7 @@ llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 5506e76424d..499e22dde81 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -163,7 +163,7 @@ llama_model_exaone4::graph::graph(const llama_model & model, const llm_gra res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index d353befdb8e..94b65a3c7c9 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -200,7 +200,7 @@ llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 75f2cfef560..ad546ef2db5 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -152,7 +152,7 @@ llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 06731670007..1519682fdf6 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -130,7 +130,7 @@ llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index 6255bf740fc..ae3f9ffb530 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -163,7 +163,7 @@ llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // final logit soft-capping cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index ee510fe38b0..63a2b380e71 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -207,7 +207,7 @@ llama_model_gemma3::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 881499b0ca7..6ec3a006081 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -296,7 +296,7 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); { // final logit soft-capping diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index f45ae4cad59..4f9d8b18bc7 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -380,7 +380,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 45886b51ac1..27654b8cba3 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -275,7 +275,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index d6ef76e26d6..7c242fed298 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -185,7 +185,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index ba49c31b56b..e2dcc8b1521 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -138,7 +138,7 @@ llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 33ebe2d8800..443e35addf2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -209,7 +209,7 @@ llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 12e4790ae24..27f6706ea10 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -186,7 +186,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits if (hparams.f_logit_scale) { diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 5e7c7b68181..cda4aa231fa 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -145,7 +145,7 @@ llama_model_granite::graph::graph( res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 0bc49d00206..7c46ec1c0f2 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -206,7 +206,7 @@ llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index feef815165b..1cab75adc7f 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -184,7 +184,7 @@ llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 44af42412f7..deb3c9671f3 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -179,7 +179,7 @@ llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp index 5fb9154bec0..da9bb74de7e 100644 --- a/examples/talk-llama/models/hunyuan-vl.cpp +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -181,7 +181,7 @@ llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f0c5580a6f4..f9ee37a24b6 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -129,7 +129,7 @@ llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index a6451dca095..2ba162605f1 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -123,7 +123,7 @@ llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index ad59b953e8d..8966131441c 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -152,7 +152,7 @@ llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index e1b8d137e38..84ea63c3136 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -189,7 +189,7 @@ llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index df6a8028736..29081344b24 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -262,7 +262,7 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index b60f67f6c4b..9722dde9f17 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -153,7 +153,7 @@ llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index fa21c5fe32c..58b2c466e17 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -147,7 +147,7 @@ llama_model_llada::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 8ddb5936820..cef66d054b0 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -235,7 +235,7 @@ llama_model_llama::graph::graph(const llama_model & model, const llm_grap if constexpr (!embed) { // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 899611d53f6..0ff5376d571 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -260,7 +260,7 @@ llama_model_llama4::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 3dbd82fd362..84cfe399027 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -141,7 +141,7 @@ llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index b7708d7fdd1..887a1fa509a 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -128,7 +128,7 @@ llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index 71996616611..d0295ec116f 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -231,7 +231,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index ff5eb6ffa5f..1ffc54fa7c6 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -251,7 +251,7 @@ llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "lmhead_scaling", -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 0dee8934692..22e291d73a3 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -158,7 +158,7 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 708da49af1f..4e6ebef82cb 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -222,7 +222,7 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 6d5f18a8e20..7e551eb965b 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,29 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { @@ -1739,6 +1762,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,6 +1808,10 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index cfc60e8de29..0229d20ed36 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -161,7 +161,7 @@ llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 865461f61db..a82f9c170b4 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -174,7 +174,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 0c72ed297aa..5d4a3b5c69e 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -140,7 +140,7 @@ llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 161035e72bc..cfcf17bcb03 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -133,7 +133,7 @@ llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 9633f269965..7cc262f5504 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -198,7 +198,7 @@ llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 4bb9013054c..7976ae44a51 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -164,7 +164,7 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 13a590ce646..15b6c8c1205 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -160,7 +160,7 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index b4128e116e7..9f76350fd4d 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -162,7 +162,7 @@ llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 7ace0a5139d..bcb4bbba4b1 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -132,7 +132,7 @@ llama_model_orion::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 1c0eadefa98..d39220bd778 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -98,7 +98,7 @@ llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 41b7e2ac23e..7593f879b24 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -148,7 +148,7 @@ llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index a333602c72d..8f3ed5f7b7d 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -130,7 +130,7 @@ llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 0a65e91fefa..f8a4a4d5aa5 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -179,7 +179,7 @@ llama_model_phi3::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cb(cur, "result_output_no_bias", -1); diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 4c16c20a0d4..c7ed1211c31 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -127,7 +127,7 @@ llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 29c8702606a..b713889fe72 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -185,7 +185,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); // Explicitly mark as output tensor to ensure proper backend assignment diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 849f1579e63..29f3e803d68 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -186,7 +186,7 @@ llama_model_plamo3::graph::graph(const llama_model & model, const llm_grap cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 57f5995103b..ce050919e6a 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -204,7 +204,7 @@ llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index cdc076cdf77..00467dbad7d 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -131,7 +131,7 @@ llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 6320458a13b..a5147460bae 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -141,7 +141,7 @@ llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7587c802c68..7cb03859deb 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -184,7 +184,7 @@ llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 1a40fa89be4..d79db682cd4 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -134,7 +134,7 @@ llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index fa656c84ea0..41b97fed956 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -147,7 +147,7 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index f276be61ba8..04ecc18fcdc 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -12,16 +12,22 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -43,50 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +208,13 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -167,7 +222,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -297,8 +352,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -328,41 +381,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -413,7 +439,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -424,18 +450,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -471,3 +486,151 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // hparams.n_layer includes both main model layers and MTP layers. The MTP + // layer is stored immediately after the main layers in model.layers[]. + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index cf05dc9d61c..dc24f6ed537 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -46,60 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +231,13 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -180,7 +245,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -310,8 +375,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -341,41 +404,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -426,7 +462,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -437,18 +473,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -525,3 +550,183 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). + ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 4440b83aa45..a4f8e1379c9 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -168,7 +168,7 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cb1b4814caf..1d873427db5 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -176,7 +176,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 7871f8f7952..5defd893944 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -163,7 +163,7 @@ llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index b99143c8908..5b77df57122 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -180,7 +180,7 @@ llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index f14f10917ff..bf3949a9092 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -150,7 +150,7 @@ llama_model_refact::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index 325ee73ba5c..ca8e009615e 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -167,7 +167,7 @@ llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 2944711acec..ba2a9dfa0db 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -176,7 +176,7 @@ llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 6f7d1f5722f..566b8cdcb54 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -158,7 +158,7 @@ llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index b205e3935e1..7574b252621 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -202,7 +202,7 @@ llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 83e114740b6..806cba574be 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -141,7 +141,7 @@ llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 3214e7cbad3..4231cccc666 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -178,7 +178,7 @@ llama_model_smallthinker::graph::graph(const llama_model & model, const ll res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 7adaf34c534..90e7d473eaf 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -143,7 +143,7 @@ llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 8f613e55947..4da7f7aefcf 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -163,7 +163,7 @@ llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 58cf0ac0edc..e131af058bc 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -135,7 +135,7 @@ llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 45dae0602d4..9c207c02885 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -148,7 +148,7 @@ llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index c4789752d21..3b68e68707a 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -261,7 +261,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 27a0711ba41..73e32741406 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -265,7 +265,7 @@ llama_model_t5::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index a873e5d2e8f..214fed99bad 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -253,7 +253,7 @@ llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_ LLM_NORM, -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index e4d111e622a..d6d1c7a2e5d 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -126,7 +126,7 @@ llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index dc13e53f09f..b02ecdc930f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -605,6 +605,136 @@ static std::vector unicode_regex_split_custom_qwen2(const std::string & return bpe_offsets; } +// Qwen3.5 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +// Compared to Qwen2, letter-runs also consume Unicode combining marks (\p{M}): [\p{L}\p{M}]+ instead of \p{L}+ +static std::vector unicode_regex_split_custom_qwen35(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || flags.is_accent_mark || _get_flags(pos + 1).is_accent_mark || _get_flags(pos+1).is_letter) { + pos++; + while (_get_flags(pos).is_letter || _get_flags(pos).is_accent_mark) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{M}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -929,6 +1059,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen35(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); From 44a50ca41a574d596eef4b2d0a4ffcc1575d6000 Mon Sep 17 00:00:00 2001 From: Kaihui-AMD Date: Mon, 25 May 2026 17:27:42 +0800 Subject: [PATCH 143/289] readme : add AMD ROCm/HIP GPU build instructions (#3823) Signed-off-by: Kaihui-AMD --- README.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/README.md b/README.md index 474a1301da7..050a35be21c 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp - [Vulkan support](#vulkan-gpu-support) - Support for CPU-only inference - [Efficient GPU support for NVIDIA](#nvidia-gpu-support) +- [AMD ROCm GPU support](#amd-rocm-gpu-support) - [OpenVINO Support](#openvino-support) - [Ascend NPU Support](#ascend-npu-support) - [Moore Threads GPU Support](#moore-threads-gpu-support) @@ -340,6 +341,27 @@ cmake -B build -DGGML_VULKAN=1 cmake --build build -j --config Release ``` +## AMD ROCm GPU support + +With AMD GPUs the processing can be accelerated via HIP/ROCm. +First, make sure you have installed [ROCm](https://rocm.docs.amd.com/en/latest/). + +Now build `whisper.cpp` with HIP support: + +``` +cmake -B build -DGGML_HIP=1 -DAMDGPU_TARGETS="gfx1201" +cmake --build build -j --config Release +``` + +Replace `gfx1201` with your GPU architecture. You can find it with: + +``` +rocminfo | grep "gfx" +``` + +Common architectures: `gfx1100` (RX 7900 XTX), `gfx1101` (RX 7800 XT), `gfx1201` (RX 9070 XT). +For multiple GPUs with different architectures: `-DAMDGPU_TARGETS="gfx1100;gfx1201"`. + ## BLAS CPU support via OpenBLAS Encoder processing can be accelerated on the CPU via OpenBLAS. From 2979e5f95fa95c4f30fd36eb5e4766448da9da44 Mon Sep 17 00:00:00 2001 From: Gilad S <7817232+giladgd@users.noreply.github.com> Date: Mon, 25 May 2026 11:33:29 +0200 Subject: [PATCH 144/289] ggml: `gguf_init_from_callback` and `gguf_init_from_buffer` (llama/22341) * ggml: implement `gguf_init_from_buffer` * test: `gguf_init_from_buffer` * fix: memory breakdown for a model loaded with `no_alloc` from a file is consistent with being loaded from a buffer * fix: use `GGML_UNUSED` Co-authored-by: Copilot * fix: remove `total_size` from `gguf_reader` * fix: file offset calculation, rename `offset` to `data_offset` Co-authored-by: Copilot * refactor: extract model loader bug fixes to another PR * feat: add `gguf_init_from_callback` * fix: always require a max expected size * fix: change `gguf_reader_callback_t`'s `output` type to `void *`, change `max_expected_size` and offsets to `uint64_t` * fix: harden against offset overflow in buffer read * fix: remove seek behavior from the callback * feat: `max_chunk_read == 0` means `SIZE_MAX` * fix: seeking in a gguf file with no tensors --------- Co-authored-by: Copilot --- ggml/include/gguf.h | 10 ++- ggml/src/gguf.cpp | 178 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 163 insertions(+), 25 deletions(-) diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h index 02d5f221c03..67851ba6f16 100644 --- a/ggml/include/gguf.h +++ b/ggml/include/gguf.h @@ -76,10 +76,16 @@ extern "C" { struct ggml_context ** ctx; }; + // callback to simulate or wrap a FILE pointer - read up to `len` bytes at `offset` into `output` and return the number of bytes read + typedef size_t (*gguf_reader_callback_t)(void * userdata, void * output, uint64_t offset, size_t len); + GGML_API struct gguf_context * gguf_init_empty(void); GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); - //GGML_API struct gguf_context * gguf_init_from_buffer(..); + GGML_API struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params); + + // max_chunk_read is the maximum number of bytes that the GGUF code will read at once from the callback, a value of 0 means no limit + GGML_API struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params); GGML_API void gguf_free(struct gguf_context * ctx); @@ -87,7 +93,7 @@ extern "C" { GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); // padded to gguf_get_alignment if and only if the gguf_context contains at least one tensor GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index ab3cc974867..5e198618251 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -228,9 +228,18 @@ struct gguf_context { }; struct gguf_reader { - gguf_reader(FILE * file) : file(file) { - // read the remaining bytes once and update on each read - nbytes_remain = file_remain(file); + gguf_reader( + gguf_reader_callback_t callback, + void * userdata, + size_t max_chunk_read, + uint64_t data_offset = 0, + uint64_t nbytes_remain = 0) + : callback(callback), + userdata(userdata), + max_chunk_read(max_chunk_read), + data_offset(data_offset), + nbytes_remain(nbytes_remain) { + GGML_ASSERT(max_chunk_read > 0); } // helper for remaining bytes in a file @@ -257,12 +266,10 @@ struct gguf_reader { template bool read(T & dst) const { const size_t size = sizeof(dst); - if (nbytes_remain < size) { + if (size > nbytes_remain) { return false; } - const size_t nread = fread(&dst, 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(&dst, size) == size; } template @@ -344,24 +351,71 @@ struct gguf_reader { return false; } dst.resize(static_cast(size)); - const size_t nread = fread(dst.data(), 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(dst.data(), static_cast(size)) == size; } bool read(void * dst, const size_t size) const { if (size > nbytes_remain) { return false; } - const size_t nread = fread(dst, 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(dst, size) == size; + } + + uint64_t tell() const { + return data_offset; + } + + bool seek(uint64_t absolute_offset) const { + const uint64_t end_offset = uint64_t(data_offset) + nbytes_remain; + if (absolute_offset > end_offset) { + return false; + } + + data_offset = absolute_offset; + nbytes_remain = end_offset - absolute_offset; + + return true; } private: - FILE * file; + size_t read_raw(void * dst, size_t size) const { + if (callback == nullptr || size == 0) { + return 0; + } + + uint8_t * data = static_cast(dst); + size_t total_nread = 0; + bool reached_eof = false; - mutable uint64_t nbytes_remain; + while (total_nread < size) { + const size_t chunk_size = std::min(max_chunk_read, size - total_nread); + if (data_offset + total_nread < data_offset) { + break; + } + const size_t nread = callback(userdata, static_cast(data + total_nread), data_offset + total_nread, chunk_size); + total_nread += nread; + if (nread != chunk_size) { + reached_eof = true; + break; + } + } + + data_offset += total_nread; + GGML_ASSERT(total_nread <= nbytes_remain); + nbytes_remain -= total_nread; + + if (reached_eof) { + nbytes_remain = 0; + } + + return total_nread; + } + + gguf_reader_callback_t callback = nullptr; + void * userdata = nullptr; + size_t max_chunk_read = 0; + mutable uint64_t data_offset = 0; + mutable uint64_t nbytes_remain = 0; }; struct gguf_context * gguf_init_empty(void) { @@ -394,12 +448,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vectorinfo.size()) == n_tensors); // we require the data section to be aligned, so take into account any padding - if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) { + if (n_tensors > 0 && !gr.seek(GGML_PAD(gr.tell(), ctx->alignment))) { GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } // store the current file offset - this is where the data section starts - ctx->offset = gguf_ftell(file); + ctx->offset = gr.tell(); // compute the total size of the data section, taking into account the alignment { @@ -844,6 +893,89 @@ struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_para return ctx; } +struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params) { + if (callback == nullptr) { + return nullptr; + } + + const struct gguf_reader gr(callback, userdata, max_chunk_read == 0 ? SIZE_MAX : max_chunk_read, 0, max_expected_size); + return gguf_init_from_reader(gr, params); +} + +struct gguf_file_reader { + FILE * file; + uint64_t offset; +}; + +static size_t gguf_file_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + gguf_file_reader & reader = *static_cast(userdata); + + if (reader.offset != offset) { + if (offset > INT64_MAX || gguf_fseek(reader.file, static_cast(offset), SEEK_SET) != 0) { + return 0; + } + + reader.offset = offset; + } + + const size_t nread = fread(static_cast(output), 1, len, reader.file); + reader.offset += nread; + return nread; +} + +struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params) { + if (!file) { + return nullptr; + } + + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return nullptr; + } + + gguf_file_reader reader = { + /*.file = */ file, + /*.offset = */ static_cast(cur), + }; + const struct gguf_reader gr(gguf_file_reader_callback, &reader, SIZE_MAX, reader.offset, gguf_reader::file_remain(file)); + return gguf_init_from_reader(gr, params); +} + +struct gguf_buffer_reader { + const uint8_t * data; + size_t size; +}; + +static size_t gguf_buffer_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + const gguf_buffer_reader & reader = *static_cast(userdata); + + if (offset > reader.size || len > reader.size - offset) { + return 0; + } + + const size_t data_offset = static_cast(offset); + const size_t nread = std::min(len, reader.size - data_offset); + memcpy(static_cast(output), reader.data + data_offset, nread); + return nread; +} + +struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params) { + if (data == nullptr || size == 0) { + return nullptr; + } + + gguf_buffer_reader reader = { + /*.data = */ static_cast(data), + /*.size = */ size, + }; + const struct gguf_reader gr(gguf_buffer_reader_callback, &reader, SIZE_MAX, 0, size); + return gguf_init_from_reader(gr, params); +} + struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { FILE * file = ggml_fopen(fname, "rb"); From bcff51515008baef985289b013e2aca876cdbff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 25 May 2026 11:37:25 +0200 Subject: [PATCH 145/289] TP: fix ggml context size calculation (llama/22616) * TP: fix ggml context size calculation, memory leak * move split state cache back into the context * revert to constant ggml context size for cgraphs * increase headroom for statically allocated tensors * remove obsolete include --- ggml/src/ggml-backend-meta.cpp | 194 +++++++++++++++++++++++---------- 1 file changed, 137 insertions(+), 57 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 5f9ae9c1bc5..d0d64523b4a 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -392,64 +393,100 @@ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type( // meta backend buffer // +// Container to hold the tensor slices per simple ggml backend buffer. +struct ggml_backend_meta_simple_tensor_container { + std::vector ctxs; + std::map> simple_tensors; + + ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) { + ctxs.reserve(n_simple); + for (int i = 0; i < n_simple; i++) { + ctxs.emplace_back(ggml_init(params)); + } + } + ggml_backend_meta_simple_tensor_container() {} +}; + struct ggml_backend_meta_buffer_context { + // FIXME + // Most tensors can simply be stored statically in their own buffer. + // Externally created views however also need a mapping to simple tensors but they use the buffer of the view source. + // If external views are simply using that buffer they will slowly deplete its memory. + // Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp. + // Long-term: tie the lifetime of external views to the meta backend executing the graph instead, + // currently not possible due to graph-external operations in the backend scheduler. + ggml_backend_meta_simple_tensor_container stc_static; + ggml_backend_meta_simple_tensor_container stc_compute[2]; + int stc_compute_index = 0; + int stc_compute_index_next = 0; + std::vector bufs; + + // FIXME + // The size of the split state cache is unbounded and can theoretically grow infinitely large. + // However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive. static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding); - std::map, std::pair> split_state_cache; - std::map< const ggml_tensor *, std::vector> simple_tensors; - - struct buffer_config { - ggml_context * ctx; - ggml_backend_buffer_t buf; - - buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {} - }; - std::vector buf_configs; int debug; - ggml_backend_meta_buffer_context() { + ggml_backend_meta_buffer_context( + ggml_backend_meta_simple_tensor_container & stc_static, + ggml_backend_meta_simple_tensor_container & stc_compute_0, + ggml_backend_meta_simple_tensor_container & stc_compute_1, + const std::vector & bufs) + : stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} { + this->bufs.reserve(bufs.size()); + for (ggml_backend_buffer_t buf : bufs) { + this->bufs.emplace_back(buf); + } const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; } + + ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) { + if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) { + return stc_static; + } + return stc_compute[stc_compute_index]; + } }; static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; - for (auto & [ctx, buf] : buf_ctx->buf_configs) { - ggml_backend_buffer_free(buf); - ggml_free(ctx); - } delete buf_ctx; } static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; - return buf_ctx->buf_configs.size(); + return buf_ctx->bufs.size(); } static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; - GGML_ASSERT(index < buf_ctx->buf_configs.size()); - return buf_ctx->buf_configs[index].buf; + GGML_ASSERT(index < buf_ctx->bufs.size()); + return buf_ctx->bufs[index].get(); } static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; - GGML_ASSERT(index < buf_ctx->buf_configs.size()); + GGML_ASSERT(index < buf_ctx->bufs.size()); - auto it = buf_ctx->simple_tensors.find(tensor); - if (it == buf_ctx->simple_tensors.end()) { + ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor); + auto it = stc.simple_tensors.find(tensor); + if (it == stc.simple_tensors.end()) { return nullptr; } return it->second[index]; } -static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync); + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( + ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; @@ -785,7 +822,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; continue; } - src_ss[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); } @@ -1079,17 +1116,23 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co return ret; } +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync); +} + static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { GGML_UNUSED(buffer); return (void *) 0x1000000000000000; // FIXME } -static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); - ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; - const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer); +static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); - const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true); + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true); GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); GGML_ASSERT(split_state.n_segments <= 16); @@ -1104,8 +1147,8 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer std::vector simple_tensors; simple_tensors.reserve(n_simple_bufs); for (size_t j = 0; j < n_simple_bufs; j++) { - ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx; - ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf; + ggml_context * simple_ctx = stc.ctxs[j].get(); + ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get(); if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { // TODO: the following assert fails for llama-parallel even though the results are correct: @@ -1158,7 +1201,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; } else if (simple_buf != nullptr) { t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) - + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer)); + + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer)); } t_ij->extra = tensor->extra; for (int i = 0; i < GGML_MAX_SRC; i++) { @@ -1194,11 +1237,18 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer } } - buf_ctx->simple_tensors[tensor] = simple_tensors; + stc.simple_tensors[tensor] = simple_tensors; return GGML_STATUS_SUCCESS; } +static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + buf_ctx->stc_compute_index = buf_ctx->stc_compute_index_next; + return ggml_backend_meta_buffer_init_tensor_impl(buf_ctx->get_simple_tensor_container(tensor), tensor); +} + static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); GGML_ASSERT(ggml_is_contiguous(tensor)); @@ -1413,8 +1463,9 @@ static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t } static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { - const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); - for (size_t i = 0; i < n_buffers; i++) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + for (size_t i = 0; i < buf_ctx->bufs.size(); i++) { ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); } } @@ -1440,21 +1491,24 @@ bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); - ggml_init_params params = { - /*.mem_size =*/ 1024*1024*1024, // FIXME + const ggml_init_params params = { + /*.mem_size =*/ 1024*1024*ggml_tensor_overhead(), // FIXME /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; + ggml_backend_meta_simple_tensor_container stc_static; + ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts); - ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(); size_t max_size = 0; - buf_ctx->buf_configs.reserve(n_simple_bufts); + std::vector bufs; + bufs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { - ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); - GGML_ASSERT(simple_buf != nullptr); - max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); - buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); + bufs.push_back(ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size)); + GGML_ASSERT(bufs.back() != nullptr); + max_size = std::max(max_size, ggml_backend_buffer_get_size(bufs.back())); } + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); } @@ -1462,26 +1516,32 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); - ggml_init_params params = { - /*.mem_size =*/ 1024*1024*1024, // FIXME + constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals. + const ggml_init_params params_static = { + /*.mem_size =*/ ggml_get_mem_size(ctx), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; + const ggml_init_params params_compute = { + /*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts); - ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(); - meta_buf_ctx->buf_configs.reserve(n_simple_bufts); - for (size_t i = 0; i < n_simple_bufts; i++) { - meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr); - } + std::vector bufs(n_simple_bufts, nullptr); + ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { t->buffer = meta_buf; - ggml_backend_meta_buffer_init_tensor(meta_buf, t); + ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t); t->data = (void *) 0x2000000000000000; // FIXME } for (size_t i = 0; i < n_simple_bufts; i++) { - ggml_context * ctx = meta_buf_ctx->buf_configs[i].ctx; + ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get(); ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. @@ -1494,15 +1554,15 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc } } if (any_nonzero_slice) { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft); + meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft)); } else { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_buft_alloc_buffer(simple_buft, 0); + meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0)); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { - t->buffer = meta_buf_ctx->buf_configs[i].buf; + t->buffer = meta_buf_ctx->bufs[i].get(); } } - GGML_ASSERT(meta_buf_ctx->buf_configs[i].buf != nullptr); - meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); + GGML_ASSERT(meta_buf_ctx->bufs[i]); + meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get())); } return meta_buf; } @@ -1724,6 +1784,26 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } if (needs_rebuild) { + std::set used_buffers; + for (int i = 0; i < cgraph->n_leafs; i++) { + if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) { + used_buffers.emplace(cgraph->leafs[i]->buffer); + } + } + for (int i = 0; i < cgraph->n_nodes; i++) { + if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) { + used_buffers.emplace(cgraph->nodes[i]->buffer); + } + } + for (ggml_backend_buffer_t buf : used_buffers) { + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context; + buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1; + ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next]; + for (ggml_context_ptr & ctx : stc.ctxs) { + ggml_reset(ctx.get()); + } + stc.simple_tensors.clear(); + } size_t n_subgraphs = 0; size_t max_tmp_size = 0; @@ -1909,7 +1989,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); - ggml_init_params params = { + const ggml_init_params params = { /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, From 1cf8e3a9039e2b9cdcb98f1c6b09359607b711de Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:40:17 +0300 Subject: [PATCH 146/289] ggml : bump version to 0.13.0 (ggml/1510) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 03020888f97..f542f18b6d4 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 12) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 13) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From f14ae77f4082afaf10a98c17dde8282e12744d7c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:44:07 +0300 Subject: [PATCH 147/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 2c680ce9f5d..a4f87b2b9ae 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -0a37c2167fc5b81830a32d9b1691610180ed86d6 +e705c5fed490514458bdd2eaddc43bd098fcce9b From c245b3ec23239d359ce18f3be3ee0ae92525074e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 13:05:30 +0300 Subject: [PATCH 148/289] benches : update --- scripts/bench-all-gg.txt | 237 ++++++++++++++++++++++----------------- 1 file changed, 137 insertions(+), 100 deletions(-) diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt index 220bd4c98b8..1b65fc7d778 100644 --- a/scripts/bench-all-gg.txt +++ b/scripts/bench-all-gg.txt @@ -111,61 +111,61 @@ make -j && ./scripts/bench-all.sh 1 1 0 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 0 | 8.57 | 1.12 | 0.27 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 9.17 | 1.10 | 0.28 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 9.16 | 1.09 | 0.28 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 8.81 | 1.12 | 0.27 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | base | 1 | 0 | 15.60 | 1.61 | 0.41 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 16.75 | 1.54 | 0.42 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 16.64 | 1.54 | 0.43 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 16.09 | 1.55 | 0.41 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | small | 1 | 0 | 46.74 | 3.13 | 0.89 | 0.05 | f5b477ab | -| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 51.57 | 3.03 | 0.91 | 0.06 | f5b477ab | -| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 51.85 | 3.03 | 0.92 | 0.06 | f5b477ab | -| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 48.34 | 3.01 | 0.89 | 0.06 | f5b477ab | -| M2 ULTRA | METAL | medium | 1 | 0 | 125.82 | 6.46 | 2.01 | 0.12 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 143.44 | 5.97 | 2.07 | 0.14 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 143.41 | 5.97 | 2.09 | 0.14 | f5b477ab | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 131.23 | 6.30 | 2.01 | 0.13 | f5b477ab | -| M2 ULTRA | METAL | medium-dis | 1 | 0 | 114.07 | 0.90 | 0.25 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | large-v2 | 1 | 0 | 240.73 | 9.46 | 3.21 | 0.21 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 276.56 | 8.62 | 3.16 | 0.25 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 275.90 | 8.98 | 3.16 | 0.25 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 251.00 | 9.10 | 3.02 | 0.22 | f5b477ab | -| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 217.43 | 1.01 | 0.28 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 218.39 | 1.55 | 0.47 | 0.03 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 249.41 | 1.39 | 0.47 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 227.54 | 1.43 | 0.45 | 0.03 | f5b477ab | +| M2 ULTRA | METAL | tiny | 1 | 0 | 8.10 | 1.03 | 0.25 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 8.53 | 1.02 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 8.67 | 1.00 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 9.32 | 1.02 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base | 1 | 0 | 15.50 | 1.51 | 0.40 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 16.63 | 1.45 | 0.40 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 16.76 | 1.44 | 0.39 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 15.73 | 1.43 | 0.38 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | small | 1 | 0 | 45.43 | 2.93 | 0.83 | 0.05 | f14ae77f | +| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 49.78 | 2.85 | 0.84 | 0.06 | f14ae77f | +| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 50.22 | 2.85 | 0.84 | 0.06 | f14ae77f | +| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 47.08 | 2.78 | 0.83 | 0.05 | f14ae77f | +| M2 ULTRA | METAL | medium | 1 | 0 | 125.19 | 6.10 | 1.88 | 0.12 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 142.49 | 5.59 | 1.90 | 0.14 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 142.63 | 5.68 | 1.92 | 0.14 | f14ae77f | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 130.98 | 5.83 | 1.87 | 0.13 | f14ae77f | +| M2 ULTRA | METAL | medium-dis | 1 | 0 | 113.95 | 0.88 | 0.24 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v2 | 1 | 0 | 239.27 | 8.97 | 2.92 | 0.21 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 275.07 | 8.56 | 2.92 | 0.24 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 274.28 | 8.62 | 2.93 | 0.24 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 248.90 | 8.32 | 2.81 | 0.22 | f14ae77f | +| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 214.26 | 0.97 | 0.27 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 222.47 | 1.49 | 0.45 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 250.56 | 1.35 | 0.45 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 228.57 | 1.33 | 0.43 | 0.03 | f14ae77f | make -j && ./scripts/bench-all.sh 1 1 1 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 1 | 6.06 | 0.96 | 0.22 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.51 | 0.93 | 0.22 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.47 | 0.93 | 0.23 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.16 | 0.94 | 0.21 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | base | 1 | 1 | 10.63 | 1.37 | 0.32 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.75 | 1.27 | 0.33 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 11.73 | 1.25 | 0.33 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.17 | 1.28 | 0.32 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | small | 1 | 1 | 31.74 | 2.55 | 0.67 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.21 | 2.47 | 0.69 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 36.22 | 2.47 | 0.70 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 32.73 | 2.45 | 0.66 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | medium | 1 | 1 | 86.94 | 5.21 | 1.49 | 0.09 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 104.31 | 4.93 | 1.51 | 0.10 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 104.09 | 4.98 | 1.51 | 0.10 | f5b477ab | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 92.13 | 5.06 | 1.45 | 0.09 | f5b477ab | -| M2 ULTRA | METAL | medium-dis | 1 | 1 | 76.67 | 0.81 | 0.20 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | large-v2 | 1 | 1 | 167.66 | 7.56 | 2.25 | 0.16 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 203.09 | 7.13 | 2.29 | 0.20 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 202.53 | 7.12 | 2.29 | 0.20 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 177.48 | 6.94 | 2.18 | 0.17 | f5b477ab | -| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 145.61 | 0.91 | 0.23 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 146.95 | 1.33 | 0.36 | 0.03 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 178.57 | 1.17 | 0.36 | 0.03 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 156.19 | 1.21 | 0.34 | 0.03 | f5b477ab | +| M2 ULTRA | METAL | tiny | 1 | 1 | 6.03 | 0.86 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.46 | 0.84 | 0.21 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.46 | 0.85 | 0.21 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.14 | 0.88 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base | 1 | 1 | 10.87 | 1.24 | 0.31 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.98 | 1.18 | 0.31 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 12.07 | 1.18 | 0.31 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.13 | 1.19 | 0.30 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | small | 1 | 1 | 31.46 | 2.37 | 0.63 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.16 | 2.31 | 0.65 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 36.57 | 2.31 | 0.65 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 32.94 | 2.27 | 0.63 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | medium | 1 | 1 | 89.86 | 4.92 | 1.41 | 0.09 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 107.12 | 4.72 | 1.42 | 0.10 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 107.00 | 4.70 | 1.42 | 0.10 | f14ae77f | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 94.93 | 4.56 | 1.37 | 0.09 | f14ae77f | +| M2 ULTRA | METAL | medium-dis | 1 | 1 | 79.66 | 0.78 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | large-v2 | 1 | 1 | 170.06 | 7.13 | 2.15 | 0.16 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 205.16 | 6.80 | 2.18 | 0.20 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 204.22 | 6.69 | 2.16 | 0.20 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 179.78 | 6.35 | 2.13 | 0.18 | f14ae77f | +| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 148.11 | 0.89 | 0.22 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 149.23 | 1.29 | 0.34 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 180.77 | 1.13 | 0.35 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 158.66 | 1.10 | 0.33 | 0.03 | f14ae77f | ## M4 Max @@ -233,20 +233,6 @@ make -j && ./scripts/bench-all.sh 1 1 0 make -j && ./scripts/bench-all.sh 1 1 1 -| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M4 Max | METAL | tiny | 1 | 1 | 8.23 | 0.71 | 0.16 | 0.01 | 47fcd7da | -| M4 Max | METAL | tiny-q8_0 | 1 | 1 | 8.47 | 0.67 | 0.16 | 0.01 | 47fcd7da | -| M4 Max | METAL | base | 1 | 1 | 15.47 | 1.12 | 0.26 | 0.02 | 47fcd7da | -| M4 Max | METAL | base-q8_0 | 1 | 1 | 15.70 | 1.05 | 0.27 | 0.02 | 47fcd7da | -| M4 Max | METAL | small | 1 | 1 | 49.82 | 2.37 | 0.53 | 0.05 | 47fcd7da | -| M4 Max | METAL | small-q8_0 | 1 | 1 | 51.76 | 1.99 | 0.53 | 0.05 | 47fcd7da | -| M4 Max | METAL | medium | 1 | 1 | 147.76 | 5.52 | 1.27 | 0.12 | 47fcd7da | -| M4 Max | METAL | medium-q8_0 | 1 | 1 | 153.98 | 4.59 | 1.24 | 0.13 | 47fcd7da | -| M4 Max | METAL | large-v2 | 1 | 1 | 282.89 | 9.06 | 2.11 | 0.22 | 47fcd7da | -| M4 Max | METAL | large-v2-q8_0 | 1 | 1 | 296.43 | 7.44 | 2.09 | 0.23 | 47fcd7da | -| M4 Max | METAL | large-v3-turbo | 1 | 1 | 249.91 | 1.65 | 0.38 | 0.04 | 47fcd7da | - | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | | M4 Max | METAL | tiny | 1 | 1 | 8.23 | 0.72 | 0.16 | 0.01 | 47af2fb7 | @@ -262,41 +248,77 @@ make -j && ./scripts/bench-all.sh 1 1 1 | M4 Max | METAL | large-v3-turbo | 1 | 1 | 256.23 | 1.61 | 0.38 | 0.04 | 47af2fb7 | +## M5 Max + +make -j && ./scripts/bench-all.sh 1 1 0 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M5 Max | METAL | tiny | 1 | 0 | 4.88 | 0.65 | 0.17 | 0.01 | f14ae77f | +| M5 Max | METAL | tiny-q8_0 | 1 | 0 | 4.84 | 0.63 | 0.17 | 0.01 | f14ae77f | +| M5 Max | METAL | base | 1 | 0 | 8.95 | 1.02 | 0.24 | 0.01 | f14ae77f | +| M5 Max | METAL | base-q8_0 | 1 | 0 | 9.12 | 0.94 | 0.24 | 0.01 | f14ae77f | +| M5 Max | METAL | small | 1 | 0 | 25.61 | 2.15 | 0.52 | 0.03 | f14ae77f | +| M5 Max | METAL | small-q8_0 | 1 | 0 | 25.77 | 1.93 | 0.50 | 0.03 | f14ae77f | +| M5 Max | METAL | medium | 1 | 0 | 73.96 | 4.61 | 1.16 | 0.08 | f14ae77f | +| M5 Max | METAL | medium-q8_0 | 1 | 0 | 74.89 | 3.94 | 1.12 | 0.08 | f14ae77f | +| M5 Max | METAL | large-v2 | 1 | 0 | 132.06 | 6.91 | 1.86 | 0.13 | f14ae77f | +| M5 Max | METAL | large-v2-q8_0 | 1 | 0 | 132.56 | 6.00 | 1.76 | 0.13 | f14ae77f | +| M5 Max | METAL | large-v3-turbo | 1 | 0 | 119.34 | 1.30 | 0.32 | 0.02 | f14ae77f | + + +make -j && ./scripts/bench-all.sh 1 1 1 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M5 Max | METAL | tiny | 1 | 1 | 4.31 | 0.59 | 0.13 | 0.01 | f14ae77f | +| M5 Max | METAL | tiny-q8_0 | 1 | 1 | 4.51 | 0.55 | 0.12 | 0.01 | f14ae77f | +| M5 Max | METAL | base | 1 | 1 | 7.77 | 0.91 | 0.20 | 0.01 | f14ae77f | +| M5 Max | METAL | base-q8_0 | 1 | 1 | 7.67 | 0.78 | 0.19 | 0.01 | f14ae77f | +| M5 Max | METAL | small | 1 | 1 | 20.90 | 1.76 | 0.40 | 0.03 | f14ae77f | +| M5 Max | METAL | small-q8_0 | 1 | 1 | 21.32 | 1.62 | 0.38 | 0.03 | f14ae77f | +| M5 Max | METAL | medium | 1 | 1 | 60.40 | 3.98 | 0.89 | 0.07 | f14ae77f | +| M5 Max | METAL | medium-q8_0 | 1 | 1 | 60.72 | 3.35 | 0.86 | 0.07 | f14ae77f | +| M5 Max | METAL | large-v2 | 1 | 1 | 110.57 | 6.06 | 1.41 | 0.12 | f14ae77f | +| M5 Max | METAL | large-v2-q8_0 | 1 | 1 | 110.92 | 5.00 | 1.31 | 0.12 | f14ae77f | +| M5 Max | METAL | large-v3-turbo | 1 | 1 | 98.36 | 1.19 | 0.27 | 0.02 | f14ae77f | + + # RTX 5090 make -j && ./scripts/bench-all.sh 1 1 0 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 0 | 2.20 | 0.51 | 0.13 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.35 | 0.52 | 0.14 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base | 1 | 0 | 3.97 | 0.77 | 0.20 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.20 | 0.73 | 0.20 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | small | 1 | 0 | 11.87 | 1.48 | 0.40 | 0.02 | f5b477ab | -| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.40 | 1.59 | 0.42 | 0.02 | f5b477ab | -| RTX 5090 | CUDA | medium | 1 | 0 | 32.63 | 3.11 | 0.82 | 0.04 | f5b477ab | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 31.80 | 3.23 | 0.84 | 0.05 | f5b477ab | -| RTX 5090 | CUDA | large-v2 | 1 | 0 | 52.22 | 4.66 | 1.18 | 0.06 | f5b477ab | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 51.11 | 4.37 | 1.15 | 0.07 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 48.72 | 0.70 | 0.18 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 47.81 | 0.64 | 0.16 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | tiny | 1 | 0 | 2.17 | 0.38 | 0.10 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.31 | 0.37 | 0.10 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base | 1 | 0 | 3.94 | 0.56 | 0.17 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.13 | 0.53 | 0.14 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small | 1 | 0 | 12.06 | 1.09 | 0.34 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.50 | 1.11 | 0.30 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | medium | 1 | 0 | 33.08 | 2.38 | 0.70 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 32.57 | 2.26 | 0.62 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v2 | 1 | 0 | 54.27 | 3.68 | 1.03 | 0.06 | f14ae77f | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 53.11 | 3.22 | 0.89 | 0.06 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 50.56 | 0.58 | 0.15 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 49.39 | 0.49 | 0.13 | 0.01 | f14ae77f | make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 1 | 1.37 | 0.44 | 0.11 | 0.00 | f5b477ab | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.48 | 0.44 | 0.12 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base | 1 | 1 | 2.34 | 0.66 | 0.16 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.51 | 0.62 | 0.17 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | small | 1 | 1 | 5.53 | 1.23 | 0.32 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 5.88 | 1.35 | 0.33 | 0.02 | f5b477ab | -| RTX 5090 | CUDA | medium | 1 | 1 | 15.09 | 2.55 | 0.65 | 0.03 | f5b477ab | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 14.06 | 2.72 | 0.67 | 0.03 | f5b477ab | -| RTX 5090 | CUDA | large-v2 | 1 | 1 | 23.24 | 3.94 | 0.97 | 0.04 | f5b477ab | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 22.00 | 3.68 | 0.93 | 0.05 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 19.81 | 0.62 | 0.15 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 18.62 | 0.56 | 0.14 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | tiny | 1 | 1 | 1.29 | 0.31 | 0.07 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.45 | 0.31 | 0.07 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | base | 1 | 1 | 2.15 | 0.44 | 0.13 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.27 | 0.43 | 0.10 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small | 1 | 1 | 5.54 | 0.83 | 0.26 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 5.95 | 0.84 | 0.22 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | medium | 1 | 1 | 15.43 | 1.81 | 0.53 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 14.71 | 1.66 | 0.46 | 0.03 | f14ae77f | +| RTX 5090 | CUDA | large-v2 | 1 | 1 | 24.73 | 2.92 | 0.81 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 23.35 | 2.43 | 0.67 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 21.36 | 0.49 | 0.13 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 20.07 | 0.39 | 0.10 | 0.01 | f14ae77f | # DGX Spark @@ -318,22 +340,37 @@ make -j && ./scripts/bench-all.sh 1 1 0 | DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 264.90 | 2.03 | 0.37 | 0.03 | f5b477ab | | DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 253.56 | 1.48 | 0.27 | 0.03 | f5b477ab | +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| DGX Spk. | CUDA | tiny | 1 | 0 | 9.79 | 0.65 | 0.14 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 0 | 8.97 | 0.56 | 0.12 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base | 1 | 0 | 18.58 | 1.04 | 0.22 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base-q8_0 | 1 | 0 | 17.36 | 0.88 | 0.18 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | small | 1 | 0 | 56.78 | 2.33 | 0.51 | 0.04 | f14ae77f | +| DGX Spk. | CUDA | small-q8_0 | 1 | 0 | 55.47 | 1.99 | 0.43 | 0.04 | f14ae77f | +| DGX Spk. | CUDA | medium | 1 | 0 | 158.21 | 5.71 | 1.23 | 0.11 | f14ae77f | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 0 | 151.17 | 4.54 | 0.97 | 0.11 | f14ae77f | +| DGX Spk. | CUDA | large-v2 | 1 | 0 | 269.59 | 10.48 | 2.13 | 0.20 | f14ae77f | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 0 | 262.82 | 7.43 | 1.61 | 0.20 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 263.91 | 1.80 | 0.37 | 0.03 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 252.89 | 1.23 | 0.26 | 0.03 | f14ae77f | + make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| DGX Spk. | CUDA | tiny | 1 | 1 | 2.63 | 0.76 | 0.13 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 2.46 | 0.73 | 0.11 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | base | 1 | 1 | 4.96 | 1.24 | 0.20 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.23 | 1.08 | 0.17 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | small | 1 | 1 | 16.26 | 2.73 | 0.47 | 0.02 | f5b477ab | -| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 14.94 | 2.38 | 0.39 | 0.02 | f5b477ab | -| DGX Spk. | CUDA | medium | 1 | 1 | 51.81 | 6.94 | 1.22 | 0.05 | f5b477ab | -| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 41.51 | 5.44 | 0.93 | 0.05 | f5b477ab | -| DGX Spk. | CUDA | large-v2 | 1 | 1 | 98.54 | 11.53 | 2.05 | 0.08 | f5b477ab | -| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 91.61 | 8.49 | 1.55 | 0.08 | f5b477ab | -| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 87.20 | 1.94 | 0.36 | 0.02 | f5b477ab | -| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 80.28 | 1.38 | 0.26 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | tiny | 1 | 1 | 2.72 | 0.56 | 0.13 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 2.55 | 0.47 | 0.11 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base | 1 | 1 | 5.08 | 0.90 | 0.20 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.38 | 0.72 | 0.16 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | small | 1 | 1 | 16.95 | 2.00 | 0.47 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 15.67 | 1.67 | 0.39 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | medium | 1 | 1 | 53.12 | 5.10 | 1.24 | 0.06 | f14ae77f | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 43.64 | 3.87 | 0.91 | 0.05 | f14ae77f | +| DGX Spk. | CUDA | large-v2 | 1 | 1 | 102.15 | 9.58 | 2.02 | 0.08 | f14ae77f | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 93.86 | 6.54 | 1.49 | 0.08 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 90.29 | 1.69 | 0.36 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 82.79 | 1.13 | 0.25 | 0.01 | f14ae77f | # V100 From e0fd1f6787a5bd4a4957dd97c5b64df882ee7b0c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 13:06:33 +0300 Subject: [PATCH 149/289] release : v1.8.5 --- CMakeLists.txt | 2 +- bindings/javascript/package.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a0f74041321..2200673d0a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.4) +project("whisper.cpp" VERSION 1.8.5) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 074dfdda307..caf12b6dd2d 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.4", + "version": "1.8.5", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 27101c01dcac1676e2b6422256233cd0f1f9ae28 Mon Sep 17 00:00:00 2001 From: texasich <101962694+texasich@users.noreply.github.com> Date: Mon, 25 May 2026 23:23:41 -0500 Subject: [PATCH 150/289] cli : merge tokens split across UTF-8 boundaries in JSON output (#3751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cli : merge tokens split across UTF-8 boundaries in JSON output When a multi-byte UTF-8 codepoint (most commonly a CJK character, 3 bytes) is split across multiple whisper tokens, the -ojf/--output-json-full writer emitted each token's partial bytes as its own JSON string, producing invalid UTF-8 that chokes downstream parsers. Merge adjacent tokens in output_json whenever the accumulated text still ends on an incomplete UTF-8 sequence. The merged entry keeps the first token's id/p/t_dtw and extends t1 to the last absorbed token, which matches how segment text is assembled elsewhere. Refs #1798 * fix: address review — add braces for consistency, use full issue URL - Add braces to if/else chain for codebase consistency - Use full URL for issue #1798 reference Review: @danbev --------- Co-authored-by: texasich Co-authored-by: texasich --- examples/cli/cli.cpp | 80 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 9 deletions(-) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 4e84c1b2750..55cd71b4e55 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -31,6 +31,39 @@ static void replace_all(std::string & s, const std::string & search, const std:: } } +// Returns the number of trailing continuation bytes still needed for `s` to end +// on a complete UTF-8 codepoint. Returns 0 if the tail of `s` is already a +// complete codepoint (or if the tail looks malformed and we should stop merging). +// Used to merge whisper tokens whose bytes split a multi-byte UTF-8 character +// (e.g. CJK), so the JSON output stays valid UTF-8. See https://github.com/ggml-org/whisper.cpp/issues/1798. +static int utf8_trailing_bytes_needed(const std::string & s) { + const int n = (int) s.size(); + int i = n - 1; + // walk back past continuation bytes (10xxxxxx) + while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { + --i; + } + if (i < 0) { + // all continuation bytes, or empty — nothing we can do + return 0; + } + const unsigned char c = (unsigned char) s[i]; + int expected; + if ((c & 0x80) == 0x00) { + expected = 1; // ASCII + } else if ((c & 0xE0) == 0xC0) { + expected = 2; + } else if ((c & 0xF0) == 0xE0) { + expected = 3; + } else if ((c & 0xF8) == 0xF0) { + expected = 4; + } else { + return 0; // malformed lead, give up + } + const int have = n - i; + return have >= expected ? 0 : (expected - have); +} + // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); @@ -738,18 +771,47 @@ static void output_json( if (full) { start_arr("tokens"); const int n = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n; ++j) { - auto token = whisper_full_get_token_data(ctx, i, j); + + // Merge adjacent tokens whose bytes together form a + // single UTF-8 codepoint. Multi-byte characters (CJK + // in particular) can end up split across whisper + // tokens, which used to produce invalid UTF-8 in the + // JSON string. Refs issue #1798. + struct merged_token { + std::string text; + whisper_token_data data; + int64_t t1; + }; + std::vector merged; + merged.reserve(n); + for (int j = 0; j < n; ) { + auto tok = whisper_full_get_token_data(ctx, i, j); + merged_token m{ whisper_token_to_str(ctx, tok.id), tok, tok.t1 }; + ++j; + while (j < n && utf8_trailing_bytes_needed(m.text) > 0) { + auto tok_next = whisper_full_get_token_data(ctx, i, j); + m.text += whisper_token_to_str(ctx, tok_next.id); + if (tok_next.t1 > -1) { + m.t1 = tok_next.t1; + } + ++j; + } + merged.push_back(std::move(m)); + } + + const int nm = (int) merged.size(); + for (int j = 0; j < nm; ++j) { + const auto & mt = merged[j]; start_obj(nullptr); - value_s("text", whisper_token_to_str(ctx, token.id), false); - if(token.t0 > -1 && token.t1 > -1) { + value_s("text", mt.text.c_str(), false); + if (mt.data.t0 > -1 && mt.t1 > -1) { // If we have per-token timestamps, write them out - times_o(token.t0, token.t1, false); + times_o(mt.data.t0, mt.t1, false); } - value_i("id", token.id, false); - value_f("p", token.p, false); - value_f("t_dtw", token.t_dtw, true); - end_obj(j == (n - 1)); + value_i("id", mt.data.id, false); + value_f("p", mt.data.p, false); + value_f("t_dtw", mt.data.t_dtw, true); + end_obj(j == (nm - 1)); } end_arr(!params.diarize && !params.tinydiarize); } From ee540bf0be55d2a5176872adfb519e70f4fe0e9a Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 27 May 2026 06:22:38 +0200 Subject: [PATCH 151/289] docs : add AGENTS.md and CONTRIBUTING.md [no ci] (#3826) * docs : add AGENTS.md and CONTRIBUTING.md [no ci] This commit add AGENTS.md and CONTRIBUTING.md which are based on the same files in llama.cpp. They have been modified slightly to fit with whisper.cpp. The motivation for this is to clarify the contribution policy in whisper.cpp so that contributers can have a better understanding of the expectations and requirements for contributing to the project. --- AGENTS.md | 102 +++++++++++++++++++++++++++ CONTRIBUTING.md | 176 +++++++++++++++++++++++++++++++++++++++++++++++ media/matmul.png | Bin 0 -> 265705 bytes 3 files changed, 278 insertions(+) create mode 100644 AGENTS.md create mode 100644 CONTRIBUTING.md create mode 100644 media/matmul.png diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000000..f34f3249977 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,102 @@ +# Instructions for whisper.cpp + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Read more: [CONTRIBUTING.md](CONTRIBUTING.md) + +AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below). + +--- + +## Guidelines for Contributors Using AI + +whisper.cpp is built by humans, for humans. Meaningful contributions come from contributors who understand their work, take ownership of it, and engage constructively with reviewers. + +Maintainers receive numerous pull requests weekly, many of which are AI-generated submissions where the author cannot adequately explain the code, debug issues, or participate in substantive design discussions. Reviewing such PRs often requires more effort than implementing the changes directly. + +**A pull request represents a long-term commitment.** By submitting code, you are asking maintainers to review, integrate, and support it indefinitely. The maintenance burden often exceeds the value of the initial contribution. + +Most maintainers already have access to AI tools. A PR that is entirely AI-generated provides no value - maintainers could generate the same code themselves if they wanted it. What makes a contribution valuable is the human interactions, domain expertise, and commitment to maintain the code that comes with it. + +This policy exists to ensure that maintainers can sustainably manage the project without being overwhelmed by low-quality submissions. + +--- + +## Guidelines for Contributors + +Contributors are expected to: + +1. **Demonstrate full understanding of their code.** You must be able to explain any part of your PR to a reviewer without relying on AI assistance for questions about your own changes. + +2. **Take responsibility for maintenance.** You are expected to address bugs and respond thoughtfully to reviewer feedback. + +3. **Communicate clearly and concisely.** Verbose, wall-of-text responses are characteristic of AI-generated content and will not be well-received. Direct, human communication is expected. + +4. **Respect maintainers' time.** Search for existing issues and discussions before submitting. Ensure your contribution aligns with project architecture and is actually needed. + +Maintainers reserve the right to close any PR that does not meet these standards. This applies to all contributions to the main whisper.cpp repository. **Private forks are exempt.** + +### Permitted AI Usage + +AI tools may be used responsibly for: + +- **Learning and exploration**: Understanding codebase structure, techniques, and documentation +- **Code review assistance**: Obtaining suggestions on human-written code +- **Mechanical tasks**: Formatting, generating repetitive patterns from established designs, completing code based on existing patterns +- **Documentation drafts**: For components the contributor already understands thoroughly +- **Writing code**: Only when the contributor has already designed the solution and can implement it themselves - AI accelerates, not replaces, the contributor's work + +AI-generated code may be accepted if you (1) fully understand the output, (2) can debug issues independently, and (3) can discuss it directly with reviewers without AI assistance. + +**Disclosure is required** when AI meaningfully contributed to your code. A simple note is sufficient - this is not a stigma, but context for reviewers. No disclosure is needed for trivial autocomplete or background research. + +### Prohibited AI Usage + +The following will result in immediate PR closure: + +- **AI-written PR descriptions or commit messages** - these are typically recognizable and waste reviewer time +- **AI-generated responses to reviewer comments** - this undermines the human-to-human interaction fundamental to code review +- **Implementing features without understanding the codebase** - particularly new model support or architectural changes +- **Automated commits or PR submissions** - this may spam maintainers and can result in contributor bans + +--- + +## Guidelines for AI Coding Agents + +AI agents assisting contributors must recognize that their outputs directly impact volunteer maintainers who sustain this project. + +### Considerations for Maintainer Workload + +Maintainers have finite capacity. Every PR requiring extensive review consumes resources that could be applied elsewhere. Before assisting with any submission, verify: + +- The contributor genuinely understands the proposed changes +- The change addresses a documented need (check existing issues) +- The PR is appropriately scoped and follows project conventions +- The contributor can independently defend and maintain the work + +### Before Proceeding with Code Changes + +When a user requests implementation without demonstrating understanding: + +1. **Verify comprehension.** Ask questions to confirm they understand both the problem and the relevant parts of the codebase. +2. **Provide guidance rather than solutions.** Direct them to relevant code and documentation. Allow them to formulate the approach. +3. **Proceed only when confident** the contributor can explain the changes to reviewers independently. + +For first-time contributors, confirm they have reviewed [CONTRIBUTING.md](CONTRIBUTING.md) and acknowledge this policy. + +### Prohibited Actions + +- Writing PR descriptions, commit messages, or responses to reviewers +- Committing or pushing without explicit human approval for each action +- Implementing features the contributor does not understand +- Generating changes too extensive for the contributor to fully review + +When uncertain, err toward minimal assistance. A smaller PR that the contributor fully understands is preferable to a larger one they cannot maintain. + +### Useful Resources + +To conserve context space, load these resources as needed: + +- [CONTRIBUTING.md](CONTRIBUTING.md) +- [Existing issues](https://github.com/ggml-org/whisper.cpp/issues) and [Existing PRs](https://github.com/ggml-org/whisper.cpp/pulls) - always search here first diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..c301604f1de --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,176 @@ +# Contributors + +The project differentiates between 3 levels of contributors: + +- Contributors: people who have contributed before (no special privileges) +- Collaborators (Triage): people with significant contributions, who may be responsible for some parts of the code, and are expected to maintain and review contributions for the code they own +- Maintainers: responsible for reviewing and merging PRs, after approval from the code owners + +# AI Usage Policy + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Repeated violations of this policy may result in your account being permanently banned from contributing to the project. +> +> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file. + +Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations). + +If AI is used to generate any portion of the code, contributors must adhere to the following requirements: + +1. Explicitly disclose the manner in which AI was employed. +2. Perform a comprehensive manual review prior to submitting the pull request. +3. Be prepared to explain every line of code they submitted when asked about it by a maintainer. +4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...). + +For more info, please refer to the [AGENTS.md](AGENTS.md) file. + +# Pull requests (for contributors & collaborators) + +Before submitting your PR: +- Search for existing PRs to prevent duplicating efforts +- whisper.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier +- Test your changes: + - Execute [the full CI locally on your machine](ci/README.md) before publishing +- Create separate PRs for each feature or fix: + - Avoid combining unrelated changes in a single PR + - For intricate features, consider opening a feature request first to discuss and align expectations +- If you are a new contributor + - Limit your open PRs to 1 + - Do not submit trivial fixes (e.g. typos, formatting changes) + +After submitting your PR: +- Expect requests for modifications to ensure the code meets whisper.cpp's standards for quality and long-term maintainability +- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR +- If your PR becomes stale, rebase it on top of latest `master` to get maintainers attention + +# Pull requests (for maintainers) + +- Squash-merge PRs +- Use the following format for the squashed commit title: ` : (#)`. For example: `utils : fix typo in utils.py (#1234)` +- Optionally pick a `` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules +- Let other maintainers merge their own PRs +- When merging a PR, make sure you have a good understanding of the changes +- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you) + +Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions: +- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone. +- The pull request duplicates an existing one. +- The contributor fails to adhere to this contributing guide or the AI policy. + +# Coding guidelines + +- Avoid adding third-party dependencies, extra files, extra headers, etc. +- Always consider cross-compatibility with other operating systems and architectures +- Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple +- Vertical alignment makes things more readable and easier to batch edit +- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` +- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets +- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo` + - In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary + ```cpp + // OK + llama_context * ctx; + const llama_rope_type rope_type; + + // not OK + struct llama_context * ctx; + const enum llama_rope_type rope_type; + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline.)_ + +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` (from clang-tools v15+) to format the added code +- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) +- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices +- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggml-org/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ + +![matmul](media/matmul.png) + +# Naming guidelines + +- Use `snake_case` for function, variable and type names +- Naming usually optimizes for longest common prefix (see https://github.com/ggml-org/ggml/pull/302#discussion_r1243240963) + + ```cpp + // not OK + int small_number; + int big_number; + + // OK + int number_small; + int number_big; + ``` + +- Enum values are always in upper case and prefixed with the enum name + + ```cpp + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, + LLAMA_VOCAB_TYPE_SPM = 1, + LLAMA_VOCAB_TYPE_BPE = 2, + LLAMA_VOCAB_TYPE_WPM = 3, + LLAMA_VOCAB_TYPE_UGM = 4, + LLAMA_VOCAB_TYPE_RWKV = 5, + }; + ``` + +- The general naming pattern is `_`, with `` being `_` + + ```cpp + llama_model_init(); // class: "llama_model", method: "init" + llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove" + llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed" + llama_set_embeddings(); // class: "llama_context", method: "set_embeddings" + llama_n_threads(); // class: "llama_context", method: "n_threads" + llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free" + ``` + + - The `get` `` can be omitted + - The `` can be omitted if not necessary + - The `_context` suffix of the `` is optional. Use it to disambiguate symbols when needed + - Use `init`/`free` for constructor/destructor `` + +- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else + + ```cpp + typedef struct llama_context * llama_context_t; + + enum llama_pooling_type llama_pooling_type(const llama_context_t ctx); + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline)_ + +- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension +- Python filenames are all lowercase with underscores + +- _(TODO: abbreviations usage)_ + +# Preprocessor directives + +- _(TODO: add guidelines with examples and apply them to the codebase)_ + + ```cpp + #ifdef FOO + #endif // FOO + ``` + +# Code maintenance + +- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces. + _(NOTE: for legacy reasons, existing code is not required to follow this guideline)_ + +- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md) + +# Documentation + +- Documentation is a community effort +- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference +- When you notice incorrect or outdated documentation, please update it + +# Resources + +The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: + +https://github.com/ggml-org/whisper.cpp/projects diff --git a/media/matmul.png b/media/matmul.png new file mode 100644 index 0000000000000000000000000000000000000000..786a20492c02b4ee83fcb2a2bcefa0699ee7a55c GIT binary patch literal 265705 zcmeFZXHZqy7B0GNTWu4x0TmQ5f{K7BNCs^Y6crQ+l9ebRAUOxyZV@Dy06|iNNRTW! z8VE|1oI%MM$r|{X@?p z*t+)LFYxyZ@do<;e);+0$MXMvbyj@4CAsaNe@;D1IQ{Scr%;|e{VQ&vB?)q<+Yotx+_`fKA3%Uq7*$0a+tso%aP!1{k~i6zP4khY#*y^*Y4CvN^;;=|rN zH{TcL<=N8G!Y%c44R-u5r%^^Ka|3?G+5eQqK+PT=ec9Ly&vcup%h&3~Otp**5#kO1V=^_o^ zH|O;JH?KOH+m|iP`@Y`pox1d#rQi~4plpxQZk2oIgeJe9&r)GKal-w>gAGOQ^rtN| zjGUNmjUFKl^}Sle>(__d>?g`ps_EV=^ZxCJ zBs|?uR|uK;4o>&3nO1A;XCl`g-Hg}jU|yLmDz8^mR2(ZkUGaj!+%vL8M?c^9ey>iM z|FC^vU?5{eu%qn}7V402&Kzx*wlG2oW{(%Mg|81hRNNk1@#TwauETmT{n~N_CBrqU-&_?` zpXsYSL?iU${U)i^8nSb7Y7b`SH`Ur_+xJg+U7UKJ$Q!9Zr5>dm+o`7R_fMTtOgaly z)dXwlZ;i67u4AwIrxK67yq?O&IVq)07as2_2;es>&TdaLZf`KoTawI4GiuF5SuT{< z|7Z1v8xjv~+O)|sccdojhKgmQwszg_ojb)$OicV{O6F?{9zBv4T{8ErCo5;PdN*=7 z%iP9oLgVA(2M=~+4#u%fbdDV?7Omy~sB@gGEah+_ijZ$0GAtC=xNRGk{p^oOg?Kf; zDZ8d%X1;s3{AQLdoO0E2Y`YyZ4^b@C<&S$yhTHvy-;6XupTN@788#&!o6sCNc(7b_ zWpTDdeOabw>F>?E3l(z5y=6uXTQhk?mS;O}rruRNp_p_yE^{FI&exE0auU?CbQXA< zp3e71I+&?7%gSqLD5;`j=<7e{Puxrhsj*{MPIZ@iEAKtG_;Bl9jY4<&M6EB+8nSK7 zGFUIJy*5Lg;|Uc_iJB7o%)g#Xv>a}d8MPXyIVLjK^;{{4pe71MzixhNuroL@F(xMF z4o-ONbTF^6+CG!~ZQ5;lKR#~mq;;A-&o1@rSr)bh1zr5YElysvuu~7pw{2UyeoJJ5 z9wo$5_{aOTCgWdNM;HB1>q=!xkVd0_4u6vGR30RY%gM<}j#fjm-i_GulW0k%qs_q^ zzNvS=wj&oY;EC16kyN>U$bk25!%0&!GtHHy8LtO40!Ia`M)tL4T09ewC$@?YT{L z?8c{4Zhv}cQhY2zR+*DrDCJ_>24?${I;AfYo2Nd0{HU~9aQE)rMd>D;D9#+anLD`B z*wSClwmu1@_QzY&WN&BZDx8?m>D=GOWgf?oyV1?_*v(C&_{A~pHc{+F817zfWEFJz z-5PF6F&T_a^RXR$E-ez3!cSWD?BYa4N%rWC-D_s8sHJ-tFY}HM%_z)qXiT&Os?l1tiNDnWYpEwl~2mKyV!f|!v^O3$DFx)ICEz;cJAKo zn@0b=&UmLK1N+HJiOLj%df_;nbdA2R|0Le24Bb?gWj)sMQqaoCZzgbn ziml(bbu{V*w{}J_`n-6NKh*fnz}^RYV%mNqL6r8a9u<#ACSjbh#xZi=`O0KK3LV?$ zfkhIoGiuA`e@M$5l}$dvhHoQy zBcxcTJissgOux7`HnFqkm9doU!qB^lLR9&Vife5;`slf#C~nt!YEd{}nNnYwQs**m z%N9$aYaEP^!D#le8GUrf^AQcBK+WtHN{ukJ;A9)~U3cuLP0_DCjxqYwx^>jPSI~9f zjA7Hx`0R1X&?eNk7*4$3jF;&0ET;G6PAFiQ;I=J z_J7x*jJJK?US6XsD(W|H-h5E~!YSJ9uO%BwF?EET^}Z4;A}q2rSuK|0-57ZnGe0zw zj!ph>wB6EFeaU~8Gt+3i>&sY?=yKv86(=x$<C2Q>Ha40R((B4vF9Gg}S+FU@U6ED5%0Ax3-tkh?^HMU+>V1bkd-iPh1Fhrl zQL@n=E6V(N-iqrS=iuO&*}-5pGZHwvM4?RoT-!@k2@{nP%#3L0F+b)niE0=(ryAPD z96Ns8?bwYsFKnFNJFmGOua=Ylj*1k{zOh*xE9zXDX^*dYM^0UjUhd>b)@WgGu#;Mk zLvZGAXKhTYgvn)C)T53yI`Umz^*7M$yBWpwcsHk%tZ#5|UXlO*+#wP;Ge$B8J6g;t zTIc59sxRp?M3A?NK7alC^%m0-c5lbrHy_Kp9R@VSy4%lFZnf_xgZgYZhBQa2QLBGs z>Zx-bz*~j>JOc)AVW=xA#f(CY>~wLU|M9BfK6W*%-%lb+dwS3b+bz@;G; zqmp5oHi!x7ioF>cn?paZ6y0E%hEPkR+jxA0rOIfP%*j6EQWY@Y#*XMzv zI2tbnEcE?m7N&>y0y)(J+UB@WH{08Xd#D9)44DA$Y?r>Eo6$b3loW)#okq z?BZ10huU@n5can9FUVs6lFOc|Q?A6Vk$KnE+glXwtHwYo{Ob_)!mhn?d;GYhB>BSC&e^&S-U>$yk@Osmb{G2_F%_(y(_QIu)n$Dw|(8ZbzCML`T0xA zsdwq6Mdq}k3?#(34eCzdpaom!Bn_ptJIbl8!V3Lj_OzDmU`Or4^WQ!^2$K#L$x?ai z@jMW}tJu9Xhj}gAZ$p_DArh4L$N_9)q_X4myWj#UtN{CtjLt8~=zdhgsZL?}NHm-@hZVz-=8K;`n8uX>4 zq`5PKt?sP?r<ziX%-%i3p{!I5ATIq{75-pr~4ied+U8p zi(_(Xqz3>kW`bDd);zdilsnfWLx89Rla1CDf9|Lj3#*}qZKRxl78ObYDA(jxDUx$$ zceS*r2d(I0g1oK6uz~vH_Qf9FL{MJqZf}KOSGBSI2bXbx3Yv}ca+;^m!V0~d7CTP` z4`S|wPC?#K>~ExmWKv<{wF6{QqrofnVE-Ukpk5#Ib6s2C^XLEJ?u?t)Fo{6ck zBa97;`(o~{FNf;(soyQv?|=RqQ9r0J5u2j+Rw3S#Ek64cVEAq}`IyL*pK3tVNJeka z&1~y2qk3tfvHbLJe{L>zIHT8)k&uue^KQ-AtDhy&mCR|eYJi)P;h`3Tb%d-4*Euc? zs(Xnz%oFMp1G0s)7#10M6cS0C5zv#pnC!1}XL|g@%S&<3==Qe|#r6Nz4EPR;bOo(}Q60P+LmKU0R_ZXBYdW zotoxsQ74taAFC3|(3f?)RepC`o(Q(ea-8e-EsDGzwmr5oNxNhgJnZmOQk;HITG^$=-a@EJEQvN69)fq#E znN}!ZJ zP9;^Z;)74%E*dzwtN7si)P}7LJ0PambWpw)}jVc3i z4GOOb;a=q;tGI~V_zo$)e?AWCV(Rnv20Tx91jV7F5+NXwGQ+giU0`ETQ@K(&1vjXt z;K{|R1hbXpr9-T&5i-J_BJ+LuDb*3u@v^H&J(7)Z)!E>cx!~l8)3M&KB@8j#5@IYS-)&Y9mH?lpP@;4x~Mqnc%D0?AU| zwiV@9ROqp{B(@RsxHL#uFn2m7UM`cCBD-N#Zadl^RSrnNb?d#eM7QWMyc`kl={7Ep2FBGly;i0J{}*; zJ*0w->sCWsq;A{`1JrpUhL>21+*ZMgKtW!UVK6bpa5d_yJaP5mP$hHf=v14xf*L)T z?>%ojzD9cIOff2MH}#-D&2n9}v(SUg6S)DoYuBEnqvw0Y%B-^GRVEhY@`M|to^)&d z{-@kSLLlu8nHB~qlfOC_Q`N^1l}n$>S%d9O@pB1P4&X5KDCz4R^Qgb9svPRo?BC(R zD0n6a^)552>M&JXYZ`4PCZ?vQw$J+8eY+So4h~O(b}ma(B3A{|;$(E@{54%`)~=V0 zsu!MVd7e(jm`THJ7rCf^;TpPG;Y{AjlFTR%xx~!=H~PcAljgl-`WvWIA|nGa?2&hN7PN9{?@Tzc&Eae6}nd_(!8WfIqw~ ztHi~5^q*arw$~G>mk(N`UYZ5`8YaPjgpsnnvy&5c%(`2{y>;Hl5Uz-QbG(QDG?nT)&kbUb8S5`u`TrYEE97bc@F6`^=O_#UOI@$r9L|SME^*yvQ z0jzk%-S=GYSpt_8tB+GPcx@#mCDqY{-K+h)h4MsfJTDPH+n}30*Yj$u<1r^o9J8>k zIl)l5b7jVb2VC}TUl`4EA|R|TPIWJt)ffl`uP;3(oGKC8%cim_;38{?0-d3cIima0 zs7;4LZ}8>W5t9$=c2o679vd4Qp1PG!w4zs37V>vU=L3Qn+}%YvD*a!*u*w+1I2UmT zPU~DD93>HpZSJ)b*ru!ec?k7oZy6`xv|2CE^>(s~&RvSRnHbvA1ZrRO`RSp~zjjDH zg2YEOl<);2vhHG6+KCC|g?X3{DTy9X@6zJpOzuohIuZ5~KG8fz%MrJ<17xQ*ljUFiAXy)K zJ2_a-Nfg#T0P?P`S#>$;g;FTFU5r`)(+((LU|AK{jcux z|HksJuFwCr0nVgao%(MUfb##f8UFvjfx0Ra)zCpcPlQ=impo32?_AuCzN ziB5Bp_T}t}azVS8LEu3Mq#?Uleu>N^A}%_zy3y4K>$iqtoTqS;N;%ug>o7lQ-5Y2{ zRP`~Kq=lgu9)li4w&?l;DhCQ!xP^HN9JierVI-;ydPrSWzMmm_&=|B9Z{-_A?gP{` z&^zcHt(zOEs8hTT zVFr74tD0XY{GUi8`ik#3l#;YAB^%Tmd0CYBa<)gT1UP4EgJ;;;?}#L9O{2g~$^W+y z4adna<^pKTgMsmI)0PM|VqnE;Ps_d2DQn6Q&AavCPrMv_CzG)P>PSAw%bTWD7xQ>| zIxVjN&aZ;eRYylhsA)+)WR-{^9GAB|Z)#93Fc@RqAMUHKcQEA5Dnq)MMm%{k2JWD3 z5|-NZCpffOPpRwoPc9PR7o7ud<2@n3$~8ApPtP^qtJu48K&ZZ|@XOE$g==@2dp>%3 zN?URZqX0Mf@oq>Bc6ZydGn(qw!w(`HIiX@61GR$?AJFV0m|{;!GxeBi!v7K*wGx^x z5!DQ0RR4h{JWyuU_}uMDWT!uuC@(z1J2GmzNd@uBke+%ByPLKp+q<2frawh|;_-)cr@> z4s}#r=Je?3=twR_H5O1`NSYp4LdE*~--*Bs=3d(gkG{Wl1JUBEg;(9Tnn>oQJ_+I~ zy7S4+xv~a5Vv9=;m6!Ix)r=F3szFx>BD6`Pwr>F_ca?Cl)M(C(nu`5c!5P^#lt_zu zRV!5@#tG&Ts_@3{Ffr%=WfNwn;++CVxGafym5y^%Eez_gadr4cZl%5mgf=+=E+9U( z@rj8roZjm=1ha4pG%yAe9TujKmrY}{3t;%{saG+(>Rq>n;_ODo<$nOET?WzGh*EJp z_UZHIJK$G8ef4TDd^hUU#sYauyXNBhs2PAA(n5)ngOEPnQEzQ9_GaS2a>3E;2AuW3 zgNL#G_MLm%A5}{UvS-5VzUt527^{P88Edl}E$I64JS-$+KZ{hr&F|0sq0BHBDe4B& zxiHz@`CFHoxp2P1Z8P4L4n>1)Q3d;_0WH-8jYc9^#32-V4Y|G60UvhdU8A`RM$n6n zL1^B4TBkJfS}?S%{3E_9+ws*=3kM42)Ol)v!r9*v`cNSP2gfH%A%;#7$U_<#Av1g^ zJISpA5@KEO`rqGQ+nL&I^myYAmVCqlh?)Et)Qe>#&1Rx!&zZY7|JM0=ktwK#bYvrl z&E^{X1v-U6e9jp%;W_gbO;yawc!<4p>cMAXYZ*HXDvv=jZXX*fVlF}$qLU;Oh($8O z59pu1vb?aPel6wNLXr%fuqmL^z!Os854X@iOV7m-9wT6kNHN`Qgea5H0+BSurbtcz zo$mq^_UIPDYGOU(prsXM!mW`2L)E{HMo18k`#%1p7Jtl!f_TTNx9C@B3H!}j?tR*7 zf+}K(8^y1Y8YLHfe8v$0lup6+IjN2Z*I@i=jCB;CAe6h?{CbR?7L)_5+HAV(h+pd7c1SN+?NH{|aBIM)ldtzht#{~$>L5S~zo2@@2#+#eAv zCA1o|(+Uk=be9khy_g zH#=XTHA-y ze?_()bjo@lDk6GGtk=n^(}PzW^|!rN1bPmSz-6&~oLI||?rTpipRmwN=5z*oUjxu#LH}AnrFj%2Y$uLJrt$6l zeKHF5mzbU(hr=_|X2%%>Ez=r#2~83KAvT3Rj$SBky%L1A4}Qy*eS9aCQt$ft*ceoU zVRZbxTkTO=n=Kn+0Gc?VS@uGA2%Lg4Q3-iq5BHsGTNt>HjulyoixG6=rJ(M>(}tt{ zQrOOt;B?<9qCXOs94_1iTvbQRyK*;dg~H&>O@FwU`IwQ6VE0tEoNY6q0Z}EL_3 z#%?Ga#Mu>t9v%w#NW30mw2~kLkBIKomm&_f#26xhHu8R@*zUAjc1A`c9HQ0I!c#SU zIJW*QZMEpK8c>YKU~(5C3=xW)Owzk|VZ^V;#-?{YyWG)a5Px;%CBcVB*r9e1-B~Op zt<_SQbQK<<>tckRKD*4}s*dQLU5vwUO1(zc=clz*X(Y!5$7&DsxhsCBj!>mEk-K7I zFv3ESFCzlE_-I=Wb*zEbXDjK1m{MCfQg0j!y(_t+h}{&COC>8wP}ruN2;%V;kkpx3 zy9HU0zs{SPnH>@pb?SATK~{i~2y7{iUYD?Q?64O*;HtU7vN%b$30EpO+S}Ddy#mZ; z6J!;rEAM`~{1JrA2<^6}X9y$=)I^N|iclpATs6sYsOpoDRv)qtq#IHvA`WF3w$8PV z6D0&6*czs|CY*3C*c>-8IyrdQm6C~?PykG_2k{3__*bkTW^ae>PL#lAW}-TeC_ZXN z16Z5va3b^@)d2|dJ2ARlKhX%D@Dvz`B$zSa-p7q5Fcc-A^>|b$w&Fx=Bjtg_dBp!+uZ3K+^L?wtpcpBF=P`9Nm&sO;}K`O#1{~g^ECX2 zjdG^y`jBt*kN7J|J3M~^ps8dL>XIw^k&~AN7=st-h&Xccn{PN|9H@zwd#ifpe(^Y2 zCO`jjZ9MMZUQ7|oYwg?@W zfUp!28FAB6j}uN{hv;&rXa^cRC-2>PQY=5df9gaPM1M7|gQ$K1LOFRGlP0~0BaS1% zr1ur=GrN+(gH~3%@+DouZx|;ofewgJ3u_y(pYdX z(?WXka1#?6^OODEb;Q7t6S<{>v*=2r=?5qXDJiDKztdBYLj;b)snMB%S~dg#M2GB! zk!Jxb5$u3+9`#i(f}E#+&bt(XQ5SEW%SFN9jjd#@^X}y>O;d3h@+>M*=P!}Ix961B z#S9*Rk{OQCSQ6JA!9t|tw3RHKK5>-Nbwad;rJcUe{FQ=eNM}+JY z=JoGzOq(1RkAeKR3QfG#IihYrwrfsGc~XuIT>?;Wo!7+9xLat2<#6y5pweTY7=X=*T8` z_dH_4zA#Vd2w;x?lW8OI!Rlep#3jvax&~ zV3(`Q^87LXOzBLa?n_-{$fV4Ku<=Iqa}t=zmYH?vgaJKzz&;Q8(M~lLeYC^F!xe!< zbLkiu1}ZDq;C<9#axKeCXlR5^&31Yz-@kvKsoRQgUmfVt_wm8+fC*A&wDhOC#20bE zW1&>qEzWQK?YD#=yAgi2W5>RP+d8lV`_qefLe{?Nqpi4`Q!ROZX4IH}AB|U#F93id z(h-ZwjALVC@^9s1g&fs9r?S z)?&dV0#}UkiIRgV*UV&`Yd;tB^pJ!S@+3y}-2m zTN*ehu*Z+jBvc|r$wxa-*siWGgWqX6yISS%O?1U4B>=WglqAK#8NCUx_LA=R-@jE> zyoGjB2O_;9Uj!M!JziCN@+qRHgXsBIbNqp#PPx@B>adMHW+A3114&J%w##L->~Z#p2Z5^>8k<{8k~SlwWUIxrkxL$3^oA9d(; zt#MzHR8}080iR>`@tf|cohj;cEMTc9mH3~M>#;G=Uw}u+C z95`@} zG9%)N9%81vxx04njsT5qu5!x8nqc~Lqzo$iX!pdwKM0E`tA!^R19wL}w5kRvPid z?nD{X#|>15xzOz6`(~o*#JSnwM3 zvsV4Dx?1q1Y=;ty?v@?9g!Hf%@y8^%}Zwu=sJ(a43y`^`F$mitJ=!8tK+wE?0i_6gzt! zk%c<&q?ThPpON){jmcf@bME3|1zN<;y?fsxq0sCzfrSD{em*_g9*>?K4b1B1<(0;s zch~{G+B}esVVRzy4{*v7PR ziW*d-|B-71s2;7JmkZs!H-!gf9gf4>YB0CP>9=t!f^eJzG0$1U(!s!OGM>G9i=W-a zp>k*ED<7XX@86$|FahhnhB~xQGU&lw)r4Q(y@cUHR=>W!geFccCmIRMh9F*C8;+Pf z0JjCUta;*c72YWp@7p}RkbLj^AF%Wb4YpQ*ObO$WCjXPS|5O;aBldC5vl+~6@`7km zT`7T?8)%tgP+Z54A1|JLiMC-e+SWR2ujlgO#m=eWW(8b@JPx%(>WDhBg_mHqRiVY@ zsz~#43~2x&_SUjZL9V;wOAWGS{>mP8Wt8<*j*mj z%-4&dS(9UD6+EP5*Z1`R4|2ND)D;Tc=yGz^EWgEmA4A$Z5-AJ;%b}OCSs#}xZ%7M= z^Kt44e-++J^?6Mx_bhP5THx+e%526u|PSqE_GR}s{Rij&L(Y~dF?Z;imxDa`|uWtEN`s&qMBnXoYo6oFUHK#`pE+($izt^+9Hd5tV1e?N^cpXa(KH_FC+v*2Qbal zu8E$I2@3#{Oc*k;$Vb^Z#5Vt(wmvXk(E3LR%axZQ5CiC$Poo(H)8HjBC4>>>b{deI zs{)&>O*2*}Z?zju9Xo6QJYM{q)yvQ(LOR$;nPdOyjXzcojq0=K&r=*0EF(+~95{g6 z^rSsL9+#ZV-3rJk;JBDxQ7?7z;x#mM^X_80;CPQxGy`^}Z0p3ll9M`$K5C$aPQFiv zAxD&3JM)T&s3Gd^=F66u0!?KQ%J<}4CQ#8p|HRdiXK5VG^|Qt?cZF;M%-J6tFi% zu&`)VCoAijwY{u(;p6O^}2R_m_aKxm?#w1XeHhI~_sR1HO~GQl@!SNs;#n zMH#x|vezI90cC~kV3q1aP=RCg)DN7TK4hz_!5hArZoS8`fB#v8S>!Mq<=$MO?=JPb z0k~HQ#KVItQMPo9z4p(a2&9Y1Maf*id9OpiOwlgch2JR)*Qvs!yxY31Tj!7S=WBtP zFgK#{^n}}^M_*BXC>Rn|6aV)D)dE2462W)y(8# zU(R^!?w9ZP`XI{-pl-=HmkC19Y!E=E5|8hFH{D7$okGD`;C6!+7v#$}=0PHj0;XM$ zN{Rp_6+siPKRtNi1UQav1Nj+uD$k8pCbfdQ;C*7NP&zC9FLOg~ms9+Ab@?Yw#j zh9;7cQgT)c^Ry<#K*5_$LF)VW?*ewyG05h|V&o09WSEKb%gD&!(N3AEp~mV?+fypx*eB|d}7*L&fY(>^$giTh&`F+8)Qn7CSgt}fI`HSSdfrU(AO`|SycrJ zzu$|L^gd47>89NrF3}z@lk<)-@T1 z*xoWnttN^y5s#kgPFf#PT(f3P-FaGJEyKp7ZXf=c;bvW7@Kt%wef;hd>o;srwwoR@ zQ5r>W(tp7zA|moh2$;{*zY}=YLO;oF?%yoH++4OFT9Kft0kR=d6X>mVvvb+Y+xG4? z01ut=;YXZK7b{)Xon0>qqN8#3<;f3#o_Tu9xPBxC+miOtxLnVL>hh_7;vPC&nts{S zLoOsnrHqDAesTpds#u^#hF(bOG1dMK$3XfxM0sh2C!YXzL@D>Q+Ho#Zi5JtiQP-jE7*!DN05eV~vZo#zfgfs&Ummz*3$Kspu@KD}rKbZsJ< z1k`2O&zVMKxb0#e;Fh+x&qh2=XEA;4+O@HWAmtM|lih5j@m&$aDXm9z`}W-p;5~de zA3*L&wQcK?!*Ztt4o+?8h0O*(zI~{J%wc2?>9(SOo-#CnC(>5>t;_%imN_}r`1$@+c`3?`t6bdU|l1E311p(xXwZfQ{ zGNU^tldl&9d?iqJ+N49wH>l(lKkB}+Q}kcnXqsUR9>zI{qnTEk;k4q&#==rKoO{#U zJUIp8oN`Yd`svjBwX}PB_wwBfsqVdvS!GnqKu@m+5LqsXClRzquu+6CAdr^fK7mA4 zn_gLNl5Pc)Z?G>^bfr%&v}xP7Z#VL@LJ7#<|7Bf^V<3EhO7wzn>=VrB$7JM~yZn*tFu!H!h^h;}I&`0L8WW93m;0VXkhfI*T&K9Jm(r_$q1>`|NlrlRtbhqsm ziBdiE_%v8mRTYDll+p4&Qa0+@6w>z$!W)cp=iXc||N8Y>L`Dm}XKAo#9vVOd%!42!yDdMNo>f+T$&CfDA8=s_xq~3F&O`aN{#WP-CQE)K;iiJy0u!@Rc zIw#A8b<6j$;UT!fIXne2%+30#T&E)Z)As~C#jrMvp#ntM`{?zL-OD%8?cCSbR99Df z@CxbGL_V#yAK4-yE+L_e#&^vh}B1D0IUrI>YG&!OuhJ{mW{jNIWpgEuD01*~90T`Vy zJNci@>q^MxbDHd-n8k*LBrB z#?U+yJ(gI+M6yEe8Y_{|5fH_VxgHa*b_lu`w___lf`Kp}(|*f%bU|qKPtn5kcdkGb zp>JoYrURyvQ&9=XT^zp*%|v3&n&j7584BgIBznHaPOGe$UKCmsFa>>qJn#x$O9qtJ zV?R1rX0eu;V4=#XyTD`VcTTL+#F#IqWW+)D-F6ZTT|7|GQDAQ z7~rbK69JhBz;cfdmXrZQ=EfwZmg6#_b_303Yo1t}`_cOd zRz!;)8WXqMX~x;&7-X=c5FQouFyq_j548>o(a1OhG(KEyP$hFaatr>@>`Dy2RUABY zXbOv+!JVywVMKH(KF6VkL=K+YzpPz51yg{&z@NKVWej$}QFP(%POycDzP>vVCU{WK zG~z^_K}&ifvJyR2TnKS-F@*TO#$$ncc-n!`uX1%fp~PcQSE?8A#h?pk%7S6X&YfB< z+Li~=RcF*zTHATN)Z&JaPBc3+Z(u#0v_+Oc$tBrfjQ+- z0H$*dh-G~2%d8s4CJLBGR#tR*U5QV;m`Cg9#lTY>z6h8v14k&Z&~Hn@tI1x3vWKfQ z=ifr}Xg^x65dReIw(w}&wSWEqWDS17D$V3uk5ua4V=O1ImGY?X5(6!WaxCOjyY1#Q zGEjTZ{F7~(;M9kuL-DcA*gkrS@q{BnSG8y-Y4uT1i{M}#3u^$*%F3Z}oIyA_^6dRz zIGF2!qr(VEAyT3VzZ1kDj}EZU9@7zJU`x;Sv`jy-8<1;(%UVW%_=Ocil@0xxe*$v+GO1k zPgn*@zWAuma&|ACm^RA3oo|8QRt=+7%k7;DnDCK6=^XXi3hlLY7z_#kA*}*u4j%m2 z=U*_~5L^M3R?YhRI$y81xT{K<)~eC0$FFF*%$2V^}807_Ou#=GB}Rpy$5xM0sx5Ht^*T$MtYYIhEJqBGxA0JNj zkulpwIsAL;vW;KcDS{`|%8Y$+$3&M3+n2U5GMuK7q&XG)a&h0~7f1iWAhBW3+xT23 zT2L@#(YrM(Fx>!MK|Jdv;;3>@pFU-Z2MiYpV{C&gnrzxF=~e&yH7>EnXY~=s{br{d z-McJD$}XR3dL;IY$+qu+uhqs}KQy-5c=dKY6|5eS^T=f(MbpZZ%iaHY*U@Wdf!)WH z90gn9^A%cc%=V6H3bvYIc&s&TcDF9ZuetYPmKrO|tEbjY!P1v5J=IoZ>P~pv&%Xpi zdiu_TpuvH<*tbCX0v-To!IgUr0yfgn$io;GE?=v~vSpRz|F&fdw_6sYv6n?Fe$>4K zpjoropI`^K)3IBfdM{0UeZ6t-u3bITMw@aST^w*Gh&@%!W4t>kTj~mfNo8fFQE%Fy zMbfPgzbN{6e=ov^&U98XVBppu4rv%we?xI8Kn=zSz;spT&jpz{O+S(oXgBdpF)#o7 zz4RQx!LqNxGa|+_sz6Nx(;=_p7bakX{@|0)X-ug%YI6*<7R+;n9?m~pk!3gYWe}e= z^iuC(;MMY1_#gqIsu&u~j|VuyQ*2J++V#Yb3gw(nd@?_K3N~1Yw{-;;II!^7b?a=B z*AA)k-1+*xZo-G| z-br9suf3n0poyCisZlS0v3C;OqN(l<^?6;2pE3wQK%q*OC42RjjEqmO)n#&QEkJmS zQ%UOg&dxNmUa97~Qlq}CWv?$%9gTGHJ1Sm>hwDKUN;Mk#e(brKEeN(@PGL`pnLvH> z6FM8mc9PuR|NJDQWiI4!$O1u-lmtD{G2Q<7JNFsNn_2cAe z=n!xR7FGX*)A0{Ek<-~2+LL|qfo*ML^Om-Oe9s{u;)D?S67I(W=dGn=L zbdt2GX;NkDM-2^)yQ*ojQ8Exgf;m*upWNQJWHVD~wG>?XH8>CHO5*-2n-MgHtsg(? zf`wP?_w?PkaU*#)89q=L1ca8pieR{X?cuHW5D*5*#fWe1odnZArVyth#jEv?C~eEW zpFQ;Br%xUoQUQG5F$JK!SzkWIU@2lV{#xe74MWD_^V2)%>H9~wzj*PYYqDd_BIl7K z7hGtW>+lIdM)pJPdFn_tSA)|>gK9YW_Pr3T=r2d1>6^VlQ)lX>-ZLd_gTnj-^sgVL zJkx?E9e-WV`!b^D`RV&LEO;_VM`Q@mnsAY~4 zLj(w3Rk!Xvyb@R)PrKAx;#)CnG3x-K85Op0^m^!Dm8vnG%%YCXG=i@p@yO7y0hYJ1 zKVbqx-%x#b3fOl-o|#9^bP6u20R%LGd*7T-l#E*`8l4u7uzM>ydbgp` z(@bE=M(e`QMsrj#>Gky%bod+f34ALIG z6GBt~6I2E2HV)sZxJBk8*$NNz9`Lk#4 zp01A`-OEZ<=308tkStEQ$HK*Roik_hnaIkb*%4OO1-zuTi&-YG{)Sue$dX2*PL?@{ zWqEejtxS5?yoCuolk`r&VxWsV4hKd65eAQnrN$DjFMhCY;FYY~u%S6u#CfO}h=&2l z5MdaXX4?(D%3p?-Y^3cb@Gzh}3Q2fG$2=Y-;%Qj0sR z`DSSjQL&}|iRN(MPk3-unz(@2qCBwi-cgzCLh}c7#sSq=P2aqJeH6OI@O^FDiOQE#qF8gJQ$Gy#2q$>Y->gkOZXW-tNerqGN4Xj=}0VKB(yZal5+1@Zo0B{<3=t4V%>1mca;+9cc(8HW^?`e zCk_E34oZM2M3v5YCW_A`B)AY%Rbf2qPcs4Pg?*L;h>oKQmzUw5K!$_cQ3fEof=$)q z?8Pa>>u$-}N6CfMwVp6H>v~NEsBvD4s zkL-LBxRpuhCZy|$qY~IhoQ!b%7lr?#@#x2Kg}9jGa?z1c`<^+r>wT`a5Gvqjq~4Ae z4|!)8nbqm|Kp7&IibVERNCgQ+A?-t7pz~%q`x2&=*>UJo(z)!LMMF-iWV&sUhKiNG_}IKZQFrM4B)2FKf>;&X-}I8Rq{b@lnOA>b7Sn{;{n3HR%xfxe62U`7t$+-QtL;jxjwCCxQ^`eox59?CS%= z+pR***gSajWv*V1g&$n0Xk_*Ix%WpfCr)Q(DgN=tpSbiv!~*U@s1P@S$h6!X;&;hd zZCsr*0`5y{{2(X#MeYnT4||K`qh&84OskgGW$_OCDjO*dIXC;JVs=`4O3X z^vf>+D-lLYO8!CkKo(}%$t+B9hOI9z$>61o=ckCeQKW{z2n1cLs8siTqK++*e$aZa zHrGw2;b-Twe%-I!m}&7A&inv(9Edy7z}}T|j02DoM+Sf;HP+_a`!!+8quSv=yV8$9 z*e?tPf&6CT2)qHOtV}?7=_frm5gVP=Uw_q>A7&Hme_xeo_CvU5-lmqrY*VIv6>D=- zY^U_Y(h)+`N5;7rUV|^&xQNTA+vZa0S(dS}R)6AChWxIp$O}ZQwbWQml}LdRSpa|U zE#QE2pqo}=hc7%dk@2vvj&DbTQM_mcPbYAxmH6Sompwc;t{P-9;xV);vZHSO91DmE zKdg+s7I^96#UFewjEg2x03z9|Z$jP0cI1X1SJHfbHVNQCo(Wcqy9i6Jj!FjiGBDVd z3~pvwdWIa?5;7JXwZ2b>8cF~RXMHX?iSY5=BqnkZ`Jlf>M=K>O5=PQ1Mb5xubknVm zlL1CDd!I1=ob&y7ZqdbL2uwyZyE7>ZY3+DaEdy2agHe;yKSPu?AcqKkVk)8A`RZrW z!;>OHV!;_GP@Pjl-}-;=#NysH6kmI9ge3Je(Fu@3Kjc*p&w#6`v0?X>idS$>W%c!= ziw46ZA_ie;rCqP9PEksVydB zcp8DJH{o%7ML&o^4W!5}ZVqIoKsyt)ukL?)dv#}JKq}l_I^bWBFa|y!ww)~{08T4( zo0FbSh*uiTO}4#7^(6Hj0c;&I0_Y7_D>FIUKOgHoNO$3h>n$qr&9xV#cJhZk>d z-~o!AO0GRR8I5qsWl#CVA3zz8pba}ybT+vF=P`0 z3VDoLZbF4)ir3u$p&>VO2dnc6SARt>bUdGwD^_Qw|IQl8y!_$A2dPdpLGq-4O?De1 z)(BvAaK|Aep4#$#`)9YGjcp3j5HV6$$5&m}Xc5ld$d4{qGl?)Lsth=RiG3GrkX7nm=c(aIX>L=JyV1P3*m1 z9IBaJW5dJ50wwZb*gi`DZCYlLmOCyU%6PsjK+U%QNLT|PTxQNHF*}P%spdGR-8ly)Y;h# zV=S%S%z3u48hJum5octf`y?0cV=hpVSG>N8Il8CXQ*;cS%K}1p1#4bgO zf+!^xL`4x1K~aPxDqF=G5EZdR0V#qV5J3|=C@7$SfY=ZhMN~i(Dd#uW!tD1u?>D~j zjq{zq&Nz;>DKhOB*;E~OdWH;}1W2<4D3UOn2ZZ~Y^G|83ju z^AZ0YooF7ug||7^7&|+E5B(%?DvS@8R&6i4Bz#14O7_p}I(EaWCn($1@7BQ6riaFy zAPy%U$U8m25R}QlyY}p;&}Og4v8RVzKX&ZcA|jtl>eWlq9=_{<`Nh~z>1QeHX(u(>wr6pa$$R&9ymQbnZ%;LgO4`&3XER?z z7=3wmdZlydegq!i;+$L7ITyC0)$*Ru5*gh5@iq(mI(&mqbI_R;Yj|^Jb7swe2 z;-np&e3Oi8E7=(;7qx64_(yVA-IBh#+Qm3siI66PrQa|0w5|O!*2RV5gKPa83~vVO z$4b%gw(~fTzf^{Fr_a(XHRjdoZJ$1ES_Wj{BO-(YGj^%J=Ot5%cHRXT=R$Mp*xeeH zHB;pFVq4leLYz8!07j`9uSQse&kX209fuyC{(ksc#2V({6Wc1+{z;R6(h78JW9=R` zyMKB>YfXb0z^!g^!xY`LtfKdt=QGvw4F;tL>rbcGq7(9mcj(U>zrH@Ixo|AF#>$9@ zb==A-5H!5325(uO8MY0%1$~pt<;%!qmSh~ewQUCkmD)sKj#?GnQvOQoNv6S@>S z@g=FR?uz0rSXFa=WP&40JLtC9ub4Dh=ob&51o}VizP47mOWpSS>|t*Il%?ks4+ zrv?rkY615+^j^$9A8b|noO3CdO?T@7jF2Po@$s#KS9Z{`y(@gSSckE~n_w?@{ZMcV zIpskijcbg;isp;Y_tqEJp=XY!U|`s$c4}yL@ZiB1mO>O!SIYElZbme{HXQ7SYovKT z+HKWNUH;<=7~RY=Qxdp0XGbeswrtxb&ODj%lz~fV&R>t~R=`YVy9fn8>r$~~T>+`H z-IkiG-we9=OI<}4i5szTpf{4TPOm=S#kYbK$hS&!VZ}fxe^(6f%`5ze;Pfj0JC(TA z1vc5F*yP&C)Eg|MIb@5H>I$%Dc|S|Cjvt|b=u*_?tFP|KE^fOuoW>1+`aL)*^M@@> zhMNNhIF8;)wLJF{toqFJA%Px+@Z4qjl+(~)Bw@!+Z9{6dvG@H$ZNL6HgN5tVb*xv_ zB58|(ji(LPp~M&F9Brh<7r!zB`9e_h)w9{YWj8+rPGHtP9!|g=B|V_dW7fr!Y-2VD z86caj1Y$(AzyK^H+OyWcWyJVnFrNF38T069XL>a;aH#H%NQQvbdGIg34HbutZDCi#zm&Dq2rIsx&$o(_yquPEdy%(n?80a)L$Ya`)$5#dJ2#O8He z#vQ761ixBBJR;l&!V~F|R6=}Q!~^PFvT2k?HVA->Dog$Pl~QU5iv+hK;X@OA?4`D( zEeOG0MQmn&lxmvR^J9<3Hok^52?N3D!bc4XC*xz^%Ysq8yt<;$0XB&h*U^BvMh=U* z!Rlj8MK-_!4A#*QV{eu3N762dRKet=|bVnG)*LN`bJ4bQ@V3G#G* zf0RI355MSo#brI-ac}PrML{~V<#sPGRPPdnC=ZI)2hpY5%-Y`DBi+T-w#nQD=| zsrV2`PWM)Sd#AmofxGKs>*>mx8^b)xX32_YIqHsa;s47H9o74h)Ty>Hd6u${pz6H| z)0T(}00K3iCmB%rIyN@h5DttIzUZ?dcvFb!UBY~ZNuVv=c&DWeZ`6NeY%s` z-aJO#p!^Y8Ro<+?mD#Z?fn2YJm)h(9a#;~u7X9_*ONiO4(tR2-3mxCzC-fIrI^Se@ zWf3)x0L#I*FuIySZ43A;oqc5Kuh%+zR%wX^`tHl+-G z;W}a%Q|prc=rMJ#$Abdicm4qHAi~;y%{5s3`L-}I*1b)vVF~nyjA-w0Llbc+#BJ~A ze`p?rW5?#BC8u22JE*~pA6b7wWkEwa_0aOiZzcoNX0k#8Xi4r^xviVwL#L!b3L)B@ zCPa}px4iV*algd+{xBk(oi(&ACeQ?`p1oe7uTNN@`nwh7{CUfkPWw$F8h~D2lGuVE z89x7`n%0Wm{OhfRW6`&?nN2hG%iKea8cpdXUA6F1>Lp76pqM=aZ2sgnkmzvp_kqci z9{%DFahE)&3M6ICveuS#jk0MZ2r)QsWd|Cb5c9LBp_i?7QKThtT!`e{)^5eO{*Wn= zfwnx9SDxM9Ip+z?hXEq0$z&grRMy`rJtbVuWark_xWMbR z++BqR-fPv~97{p9sU|PFHH)G@AkmPN7xBD$@Y+!+Ws-YU3vYq!z}~Juv7H?AqKFpA zV%9ot6JwQ!C+GhD>SCqooXQrEX=uniO!iI`LYK44^n8}9N~81Og!7y_(Ju1}R9csvHACg)M&AIIoA$TW}+ zz@#1BPLS?inCU>yR#jUXD@+XiWFbw3>-lZiFyfBT=u7#&PSUUWY)Ftc@g z+;rN+RR%$4RaX9T+9<(pjWK7l*ismFx(*rQtR+d0d2HtMYLEWBA!(5g^ccBY4W>1MJ`Hn!X%>3y zSp0#!4ovM6WEd|+1vhDX*9hLog%Hv(TS;w%p~EPr9feZ!vDJd%ZBk zmyi_BPH+>UQ8*bvr?$BeUyaG(t$6&b^ULN4A{kp@JcfQUd1x5*yGRaAThd#2}@q5)Za0;7-xoT}@eP7{dcT{J?xm&qSrtu%Pu=EslN zXF0)U?vCyY3Q0G+4JsIKx+HzCYIR{ z?jjklZp&IUzx>go;?W7VL1!qhv+2-m$~NxVv!`55(s(Wwo2mcVQUqyIEZ`GxA}cHF zlK0qNy(X|lnnO_&t6G?Ujn`{rDzjFGUa6z(%x13(8U3>^(D%g^rVbPqK#oKG2hBcv zTwr6FRp1|%Fg6FSQsj10?V-Ny_BbkO>((O5U-jncL^2*M?x(U2&FSc}=Y)$Mf-`$` zDvZmVfo3uw*9b~T8Ql-28}D?BzyA#Klhav0=F`IrBPSJAhXFMA0jU<~Tc~8#&y1X* zoZk97McfX;p5_SGapFHiGUQH2mF_YLfd!DY6+nA|SLM$XbtrwMlcYtR%P!c@?65EF z&h@D)6Jz~C8cR6mgU0{mZ7VNSeQfg`ySX(eq2a^(n&aMGRv}Tow{`UNX41P86^oR& z8#Zj{QvU(wm=?5R*}q7#Mn7D0KaY?k{S%R}0>WBTsQ&i*?=X!bNg*PsUFW0Ac4Fy?Ll zkt1!SW`$X0=+JglqxtOF4-cg&VQ5Ip1IZD2^}l2yLh+)Nw7~PH231;l zvy2@_7J=l_i5ffF>)6hs9DC>4X^c=XWE~eeDo)z)=6J;8(BdBM*QWu&d`>=**2Si% zu#YOv`p)?Zc^pKZ8Id!zeS)6u^7Zu{@MeTDabbYbvY6(NJw=NK@v?jktNH}ldCtM? zIV2v-mnJvA>{;hIR-MPIE~fr}Hx;Vs#Q$)af_T|eBowND3r11Pl=Zz?sD{rK`P*SC zJ+1g&ak(eyr8)V<@^$tUdSk6ry$6M2FP)PGV^yUg7l@xr649X3OflMmOkJ`mo>2Wy zhfUL=&aSzl`!{MW9Y2l>rh;WjdYV3?d{}4JByK{DdqrH^aa9dp!L)n#lb&MT#*n7- zCS9K3>u`?~00Wq*8#iv$X6m3E&5|_lvzBc}7$!rY9CKW4c}ll9+L3uGZUnH3c}qF~ zM<%Q};Z)+8(>we|2YA#iXEuM`e$zkS0Pf<2q9%V#GN97%PtU3t_wJ-e`D}q{T)I#9 z>2`B=^Ex~A(~Ywdi3CmMrhntkKXn~Ji_E6gr-8Qrwtf3hFq#jE|5&V^4uo0|iH8CF z4&QAm#G8-5{*7&T7EAQ4CuW!Ge+#;hdEO;-Jgpe2$njUT4f#!QJ-w4G)b*wOxar@` z;lh!|AuIy+MT5>4zIQ%d_@p%CB~vTqoR=C;?N6Ur&_xn#im$Bj>D_&TnqU6rw@;6C zxm<@hMwGcHQc>ay?37O*JH3%^X|*zQvcPk1i_s00ZiJ!i`PV;{?$x(%|K7c?N#t6mMQ#1{*MDwPN4v+-KG&K{?q)znOmbH0@7c=}4sM{l3xLc>>a!DIc)>))@BsR{L} zD60P=#jtJgorSuZA~HDq9`mBePPz;CrYRI7?m=LT=FGGlclCrEm+%1J5hOI~-vT=P ziE1V|O_5?aJyWew_-E9x^waq7b8^=IWbO{QmFF%W^F18}>F;9FhwQbJITgd)-u_YV z?$2^ZbxSX36?`X*^MZ`Vym<2D2tAKZ{rZjF@~~3>}+gVsgU z>cJ^gcgBqMG^BsReutm&Da_`E&X_Ugz*rx9uzY?7(E7IOxxMA&{!7w7JK3UPnP3{O zp9y}zFf1{+Qr^1z7+dU66Q0kP9uD!lb^E}!A`BuU^S*UMGUXfH*4{QV|KJKhNFitY z_7vG+DG=VKEh#$~yIF7{b(?ko-+5TEp1{FnvB8xSc!^t=w~`oXhyZ5F)dy*u8_HKaLDZE=>_z{wOKWy@hX&*V%^Zkx(2~NnqIY zzV@NvxYzWbXfrhb9}DiHyR<_zx_5VYUr8GFdlO!vP-tumZo18;i}4@-)7SN%==RSa z?)mSLL1FaYBZzPB{~l>B@BZ%*>6fYhJ#wAw^}jir{!eaiO>i4|rpJ$%KIZ_Vp=)Tv z&~3_70|EnkckkYv`Glq0bl#6O3<+4!GmtZ6=IUSwH=X*q|GHP;u0f+cnUx=EYXNYN zzfSo0;lsW(k8fD^_o#XT*|q4y^e1)h)TvWSYN{orR#8PoELD*ENa%*j1#vrr zg2DmAK@%izBu}y=OwbaY6<=ImS0}1y=7wxw$;ruKXPR9;4r&uY3AJ%-TO>)dRPp6x zoV0alP9cS=rEMOtYuAz6w~g1VUVSq=+nNu4{pFWt^xc5%kI}>1Z|d<4rmGv%4SDj# zix*!&MEAtTk-|aHYNNE`{-`ev@&?vXGjB|1&-by0UNlr!w@9hfc5rZDc-p`E0DMw% zz02}pA!R>Eu$E`&-CHRsJ0-!;p+ldSmk*Rb{`j#!m=m>goy}*5v(!`J7f-~E_t^-~ zkrqkcfSh~KyJD|3Ym}rJ)|A&cs2(Bm-@0<;%CF~MJ!7?rYb2lt$6&Ug@!$UA%A};E zRT!6PQB-iw$qs>uTJE_Jl%H=anIQumC??3f+3&p@I0WE0SqFvHRPiL>F}uzvS>LAn zm}p=bZ5MgAOS zcqB`{&45*dcp0-#&D#jdbTlVt77eMh6AnRqtpD4PBRxIets4TWW-87%VoHQhstkAU z+B#;)#7UDPDO?4Yp(c2F*h2AYgF4&u*mdYoEa0IuR;2%mJ9oq-QwksmH^Y!D0~t<- z#w%KC`kHf#mt+SqB_Zp~;NX*8h*x9jH(2f9E-W{jP zhY1f2h1xF4H~LA8as(eBp2}$jf*s4219hYX9ldORn$882U0h-~P^TR>hm!^5TRMDB zM6#*9{XtG;iiS#EipPT`D33@Gaq5c#1@gccTlKH26A(Ox+ndnZ*#G%WvxaPjP>2rR zfcKI1Nf)}Ue%uTc#h9wqBYIVLMs>g8+Vh-$Iv?0ED$Aj-uC66e-aXdBFxssckWAv$ zS-86oyAo%8l0LcA_MrL+%z2~wRpQSB&LiMo*@t=m-numtx>dEbYtVruHc2a(rJH?NRsS|#hG3DQo$1K8rDsGgDUmz6fOwnqhB3~RATV8^RW;Im5 zC?+4BJUM;W$nd3Y0EVLiSp$<2F9JyY=f}l16<8zwX4O;w^p#c#@=N=DEkYWt30OjF zXRD0#tat<=>@4@cicprxV~KqST%YsCzo$(^6}lC}z7-;jqz~lSx!)>%s4FvJ5+TNj zrN-T}`0fXv`ig^vc!LxqpO2aYpoMfKinS|-0WoXe0)BXbmkMAd-YFi9jqMaQYF>)3 zg^{uG4C;nzj!PR)pkDcM_P%}l1X=~?KMaT+_;TRmd=unQOv)p85Q<-J8#{^8KpN*6 z>S2IQiEU0N?k9YLy_qr-r9-OOFhImB;L=F*Ohx1Va<)|*kx0ycgebP2?v!6lAUt|X zmea-ROz9UJ-la8!TmUu+bbc|8XL9jj`MLnQA|tL-6ZH824vVDrG-L&z#ta3>N)5F8%k`~`c7e62hsZGiBY5ZHupK>DMK6G4b& zcGJ3zSpb?Q&@ZjW^4c^f>-lGLhoNp2vK(A&w}%(3R=j;aRfV;|n_9(x50CMNLL#ku zP!{YWxOc0;$k73loatnm+Ql6{JWiUtP|WC_lxLkg@e5czT!i~5#-C-a!XQZC0AW-B z%A4AC@ja)`P!3O!cLuNTCqPWApMLuF)b{Q^b>=h)??RjM*Bg;ZP@(VNMOe}iEpr;3 zp@4v~CP*GPzjfMzGr3i4f9yUPFE2Y=S?!Vj7`-KLcGKTZwQ+(uYDJi0kqZd%$w!zc># z#T3b5D&+xHLmSz}q(3CgT7!uhhk_;;gA0q|+Mt!Pm8jmCf2$Dc^X11#EVxZm5kxCB z{GhqrXOvGskW<~+viEH5W|#;Uc%>hLASs%E?nbErrSQ-~9!D5>fgvHw3|jimRp*>q zdq8>R2KRJ1eKvQgECg(;*JNt9!QFlK#0e8#(t-`M)S1g0#sHv{KtDX>>9@JLxk8Mg zELdhxG*m2KL?-~TL_P?gSw6@Xzz1c+Wi6G^H5n43PSAbbo)b1pa!wru(KZRUeE~C0 zeDp-gZ*c8jjIJTp<=XoC+3qR5NuIJGMy9qp_An3pSiCocp5dP2h^6y4Fd*-|G|All zmIhoTPk|kZpE4j-{NMl~-AaF2sj$AGMn|J~DM1{G(2c=Pb5Ch6{Sz0@E0sVb?ohmD z*WNAT9`AA9pQFK*7<`@uuLPAz;cq3#hy(3_mBIf`T3T>nCSNsk;leOwK!0gt>0ud{ z+3DJz;NU34(tmdsvnXXR_5Q43NxCT1AHH_KX@Y3}sy0q+w2n)`~F9ug!L~ky; ziu?7Mf>lGpu%0w&lTsrwhg*OGTJn9Y6CEufpBOu7q!QW1I`QxB(pLQES;v#e2vt;i zNnw^T9II>!pm@@;>eR~E}jlh6_1qMZt$qlv6tve09>OScdq?cDbDyww=e;zePi*i`{PduYZxJI7j z98e1PrfK~_JE%m`5JVIFX6v1Wne%rTzWBZ+5>yyLi-?e6x?2VUx=W#}^dFjf^X6TW z=Hl;tB7Ux|d{@E3%*u${s(zi9&tcRtiR_inl=CG*w=i}J$dxVtz0#lr!{Kb8tloPB z(76Vow?z*MABgF3p2{ML1sJBAQYKU+dWj}Z8hPyDLb%H8q-gbviUWSR{kE>k14`|AB8@IS4rQ7dU*wVZ> zPmv1F(Vz)~cA*!Yd`}g%j$V~+8Z@oUPN6bqg>UMp)8oSqnh){MbKm*QA?+rCS!i(d zssw|eWSh0@FVfyOaSF&013#t*ILPA#O+my$xV(@4c4d1?!a*fWcbX$zPRTQs{ua3) zfkHepI67nL@L*g#|0GmrYKOkXxg;^Gby;J1VgYEA}uE=fd-Y>1k;X@F%f- z(dTue2fks399_V@^(Nz$6qkKYsi81c7xTWU9DgAmNjlzB21C0bZr^dAw3%b?7tv|g zS#!9iJA2~;${860?p9i844?2v5&tk_d+*6vy-C*sn2&kCo$=FfoUzmafS=`qE4qtJ~YI+bVh2`T@qpwr_5d9iN!8Q z=VO{rBL^pbB>|Pu?M3NLpc)opzL1SRlgT2JL7Zik zP4Q;plyi-g8nA}Hn?*vdj;kK0=%xmBz5q$@Jc@?7T!VRJzqVXYVlDxMX^`!{$!;A8 z6Z!TqV4dC`D0HsfxY40uijk3Shzdz9ddOD<1_d$hw`GdZsNSEZb)MGU5^Su9iQ>_( zqv!sSh0<^A*h2S7`v}>llySNt=1VyybidK1)R#WOg%%-GNVCgR)%CmUX3Azc|B-9F zprxoEVm1JVn{?rntrFKjOEKxIVMH}2K8u;1JSv6ONxDh|DMGy=Nv^R|&PL*rYVT=k z3F%J!yh?F-(`uD>j<=8hW%HJaBDzhg45#rd9Hnh&HhF-hw^D!$fM#q4~@%R zx?k66yV8H&ia5hYJl2%5nKvQ-kN?-wtR&N&@E$yhWHL{z)(HhJrq%+KXHFl~CmA`QRZTu|nboJ}IydS(K6t?*__e}+~ z^-Cm+(rwtCcu;9mze@VZY{2}Aq;?Xf?S4xC{tuCRtZ%IMz8K-lXN)n>8a^KPg5ZJhkUb4=9Px6U8o#!DAHvs0T>q}%vM{_TFfz0Dd2wyBnPPUKrzy}B5# zsA~c=cON6ivZ?*y(P?2hDL3>0tuPX2sgNgwt-N3DK9`r$k3AdsjdfH7|5-KFH(iOZ zOm50D=f{^8DR!T8$Wu*{zvdD&@H94YvfVZ6(9zk|=WM*E2L#SMG)nPNt-bQ?r7bVz zxA8XB)i&o+?W5~zqq!-bMTi1y)we48rD)RxOi#@~kHCu!nCWn6W@6bjZi+M`C7tE-#LUY3omKhi4xT>VP48z=|27jJ z2G2~t9sSU8kd?PLzwO)qk!|^_^n!(FlzfntKjK?;O4+M^P2z8(&Xkm%;WlPfR|ooC zDp-EsY59GfwJU0}n{NB#)98$@Q69NlXF7CW`>55yQJ<1<#Z(JPDnIv2q2WQotMpHe z|F|-)Uw6!}PjKLzuCR7np1(@Bdzu?O?|!ig;Z`V|-D)a_-n!tlkZ)M? z>cy-#xt00N;tiIo`#8~m6&4oq<^}5pWK~~kPIWc^d!MOmJUptsrdqpa)@(?M|5*Cj zV|!M87=>K>hS4KNmcydud#2HADe((UaJd4tj# z_Z@Mdz0B7)!!+g9UGALkeV2Ee9%=V(-u-M&f!DKJTYipd2JYv@C{PO9Wq!qsGHjZ& zG797ag9_-&vuvKfatgFE?CDWs=x}P@!~(a}fnmBj&v@;Hg${rOgW2PL{paQj&*wLN zRZ(^aoW2bHnmc;*=m8&Vqc8R?`>Vh~eRQ*;EQ-h{?I_kM*ypYtJ9ZE>5n*D`91=Ty3_ zx*?ATQ%Q%li+TBA&hk~Iaei#gN0m+O+4zp0V=Paomj5m+uQVnxEu+Kwlf`#V zy3MgksC549bze2~qycGYSjgjVK|#)z6&?%9uSqE6J?u&D{WvD^)lk00s?gG^qQ7uY zw0N{nDX^`7;Gk_<{rexf!l@A^fN&DL>G$thuRq=MQKi+GMvvm+S81`WFCU8jI9=7? zGR?*UjW>S}pq$3X-n950p5Cvps~;BfI*8M>`J?{Bsr*3D1sD8obEs48PV@g%0v zL5)!P&KF(*BqlNtPw>w@$#%2c3eTi~%mkahHv&F=TV1Wk@1!#XIL#I&i~|67cH~#E z^;0JH8aJ|un)EY-CRLB8Rbm(9HOdW_Fu!mr;NU0r<+PG?yWMb;jQi%;0RzlHj-A=b zN1Q<}Kuabo4C!CCMjfkWO{a_$S1)$S4(;1JDz{GQ1tj?J?Y7=UUw@Q7l>jHQn-^A6 zMw}>W`NXIvzjtr`DF+(&FD6=Cmxfk!fnXk+IMwxDzI=H-$7$(@%(Vu6q)Qq|hEP&? z=~7<7yvJYrdBIe0Pc$rqXCCLBYNWoA5syeFO#0>oMyh_eUx5?2;T>Fxc_W{ubE#@~3!xW?e=cv%FXP#O zmMO@3fc)Lp%eI{_1<1B8#Y4xxXc|0({$2^!%G2a52k>An7wqTfXYXA-lgqs*DhdTq zpf=hC09ryq%_c183}uRqox#x@W^LBVx#RiMA2c<~t>R{U_TKqZT}BcYO{zO^sxMiD|U#~xhtpX|HE zXhTCi3RKQ62R1rKqhBdfcBTLKrC4*(D&)S5rN@MliyQxrNqFbLNK4d$qy9$ny(cg`MV}w=8a7 zn$|hYlqZOVGvkQb@L(90^#(_?#Jh!dVdC?H;Hjq}+AcNtOfhYlpz7az)4y)l?%neY zRJ6Yz#TQe@b9`(spwC4BW!&2d&I=H^HqW^i{TWr$47O;Hk9QcK#XtXIG$gwFqL}21 zqcdA_vWgGS#-^K?VXdF^XUbh_9sdbGT*Fy&Zz!$r3=C zM@y1t8qS*V#+MR4ailU}=wfu2sY!PVrQJrm2Mob?=tzXTfUw{A;EM)uC~M+n`pQdm zaju5y^o*DpwFiIP0kTPJJHAD|_mUi*CD23IYA^wLGPI*R_?t2dedj_F2r5=kjKsRO zW5whKjT+@#RQ!9H z?OQ@YLgi`^N8a}r|1~+Kq`*DX1Y8lM-H0Mk>}8F^bL4c=YPv)`@gDFkOEDJxMD5W3 zfrSJ_`7@T`d@z#LQcEL$)s@!zdrePm6;Tx14BzI+bkeCz+7N{o+Pw}A>y9|bBwTk- znJLaeqOIfDW-rD7^i)@xhmRgTlFN~!)Us-98=7}igg&cnC*!B+`btESuv(TGJR2k$ z3+am+T{6jQ6e~_?$B|$<&|Nr1?@=fOnc5zfWNIhy=g3$~{y82Ao;7%aNqEkEAc0VD zMcEIzS)sHW6*RR?6xkfJ5`zWJEUzz5!LXEOxp>wZTyqiA8|i_8?T1-5SfJ@22E-T_ z|L^BxkzZ@HYd2rncP>~XnuX`KVB=Y)OI zY;vD4%d#}s9H7~!%37xr9_6HjP&WHgD%x!RK882F&otpSo%1Ns0dW~V<3gWaIwHN7 z|8WV&2eB&~Ug7@LVR!X9Gg}G*%3(p7;=h}zy^+(7|D+%;;i0AJ(XCrIIZ=dQz91sF zmF&o{4kDNXbcw=9QaP_iT5rmLpBcbR1XPY6x?D=76F9!&V&!(B3K*4tU>^Y%PaBAd zl==~4LfJHH5wW|gCZ)C{?ZnF{WMW+RI!_6g{w}(Si21;u?E=H8dG|n6oF1x|ozC+sYDR?w*TNn^NYY@nD;0F=v$KW|*b4ccfvk z=YrPX=EC8oepb=Q*7vbLw`Ar(n!@5H!$#$QegvbU$a)<4hv#06-2+4DT}SNI9%9@?Mn!Gmdy>09K0 z7(h#LL)%qt^3^xrWRhBm*0^WOeaz}Coa!_|0o#duWR^bfYEsO{dfrvkqeNS2HfUHwsHwe|KyNu!gya0j7Moy|+DA7P>E&P)g0eVdx;is?PzReih_+h$G z(r|(!5p3gd28EVb@o{<#>c=1bVx4A^uS?H}PTCAnGDUYn#K6AMDbfGR*B38dtm4%4 zCBWcvF5!XG#op`INyAP2e(}?BeeY!jFY5%J{)>TQrD1Zw?JHAqO4S3PTp8&3d;LFn zgJTBAY0!F{O7jP#_>b`}W0Q*Z%&B7KlKhYNVxui8FJD%pXKNyJltX>w?VzD>zuB2t zh}0$1o&E?;aq+eVFV^)Yjopt^x6rQeK^XP?-EL%CFQri>?h+&&{n4<#a*oGz$aNXV zgB7l;+=goK@lHMFMmQUODHVtuTJTRD5Dm>sB+U6rK*ctO~bP!Ignx5Fx!VY2I*j^2`0sF{eIuChTUD>AT|?Bq=V>>*4U3Z|So{ z3_3XbnvLjZCfYsdnL?#N9lb|cYR3`~`VG}XEC97LS|ZZ~lkEGP);AFXfh8&#@!la* z3Q(4cxtutPh{y}g0z&AyZAK1uUj=k=S6tbUpEwhE7%uC5kEY!rY2fwU6U74J7fDWV zSUENS)LLAf%peV%g~B4NJ$COIuCA{8l3%pe&p#}*1o&`wiQB^#-2B$B>w_ki>2{8p z{7fiJKy%q>*$zt;Mp7Nx;6<__r7#JfJADz85QCYV0KQdoYr%Lp_@cKLAGKPXgGJQD zUJ+&Npi9YwLhdfvolG4J{er*gwB=iHmTW@@vUX;Wi3oP-yADH7GqiEJa(n0DOOLY? zh)hQ^24-D~`$+ojhv0{Wvi+AJNB82~Zt*D`^}}bb`nGKlBg&r~Y>z)P1d31J8iJ)x$sW#iwW?YFwOJU)k~{3L;oGDl~v z@IYZC46Li?RllvNv4&(Rge{gNcF(+DgXbPDo~!QYWs83y2gn_w#Up+)>~=CBVlPI& zT6jZ%?9*y1mlaY7SClYtdzvX~~NH`p`!L2sVB(Wo1iYdUU@_&LFBe@g-x5$jMBMqGM-^mSzo2 z$8OAZx%g(1_+XHOUnIjGV%(mD@n6*Kpb) zH9g(h#YHD>!Kj;3;gXzAoIKg#o%1J%CP1IpFVss&175kY(k{RtClapvuS08aSa-YT z`OEX@j1}v=y`@k@JygU+Z{6Aj(~^+gDj-3h+CkmQT#fw}1 zjTn31)TPGXB`ih#*=Yd)Lxdz|vq`1n=|$po$9 z#cg0DC%X%@Zp|EK6TQS@cpMiioxCs~gD`w*e6OZ2Cy zeRz=?apF(Ljl|mTIA#c+I-v?hUGM*0t;+7E*He@VwMJO;baVlk#Kx9z7!;@9xaR2; z(Tmc;5l;U5d_-tH1RGfFGH}$WIdr%P|AD=viK1rOZ%WNC^r=`w$p`WMQxrCHfhWaDuF#AFO%%)&d6^HcvUZmgHZGLFuWOHFdK_T!-KyN2gMh2+c#bPUuZr z>yq^IU4Mt1iUUkE*Ak5})Xy`2ka|FXM6xcM)nY~-+HMBd2nY)D%oVBmpX|r6vfH71 z7iKi~35f8^;e0HVLtR#8<%|d!H&cnn#5r*Mu)p zMHwLX0m~L*$zsT{wXj`#O$98H<4Ch;RJ1tM%r-Dw9j#*q!K_0=Doz_3Kl~8JJnOLb z?uTj2NAQzyFb8E=uuBGILlk$G8hCJ!vEN~u$T zFb#MkAzKQ>r|p}zhA_Btzi(~ea0?@<7>0I^_#d)h&p{Q*T{?VJ)`cC6jg*|yaH9K7 zDQo%Z!VHyA<>bOJae*fSLbf!;#~T96Fir21uWsu;HhIGDY1TL&Tv~)ls7QgBRJa=~ z^0TPf4OJUwpovmAqcagB%^&1R$&SS1L{bvcDk=SiYZE@NhEoNj7)EU-IQ*d5n_Gm$ zKEG76j?d;)PMn421`d^^zf6+AMO?^(^Q_=!NgMJ0?`_8%r}g&mh=y;XS~@mvpBPOD z(hFxzPE6RT$2UuTNdjUmu+lr z1lECt$%Gb{g@O3i-Hhvg)=%iZl(tL2)lrv}?bNf~%oU*z%Fgjg1z=MQV@6dvaK)te zPq(h?*)*Da&Vkziacd12U@R=;@j9kjT3TXY8ph%}4t1f;sI4BonX;;nQTOstE+mPj{|T#1C&``lk>Q!;^$66Z6m=dC%U`h0Iu!^39%wxy(LD!xdV5}-hzu3o6Wq73QxB7$489L}}fu!JR zauKSz2riOTTM|{|#<_St_EO>q4W01KLm(&xQHTQS5CoHf3puT$$Zl%Re*X~{A88H( zB|f}aOK`Mg3bRn{VqEO8F-c%q6=i)^*60C4MI2QTIJoXB+@2mNWH`hZcllX3`!K*4 zG;xealwiu4lqSGXQwg4`$+k4-mj(6I^ij+{FAPco{`h15v8Gw|eh0SBDj5x3mR(O$ zssEOMMv6Y?$|LVPTcGZU8y_~osdFn_*XU7;S&E(jfeJ|2X*OGtEs0U*uD`NLTav0W zjNt6sB|NOeU_#}-%8TioFmp6%f?&LMo&`l`jm#9Yx~ofyEq4~4k&vXaHNiB|U^1J7 zH85pTPlQlF+(l2_HVLSYSB+5Fl^Re~O=JpC;+lky$Y%|UvyFI0yKfPx8fG+hjFhh6 zpG}zO0FpB~G=3OvTAA8G?^!2*8JsDc_UM zbcAOOwq(ax!~<7ow^p1{J?Q1eQ=WxS%%ez+iS+JpPL=?{I>a)zqe@3 zG@G?0XZ5U!+EF`<%(Y{NIj*o7)nkg?FRds1`lH=X6J~t*Rqri7AMdmCu7!`!`<`ue zrY`KW@0RDn)t6$nmhQe?vo5{(fT3GLO@4mMll|Bm1%eV1ZKnCx!b2-1U<;06GI_;u zrXom=>Nzz-$H~GX^%Aqr;O$Rc*_|&y2e+HI`<)L-P9`zN@br_BkdsBakUMCKSm*M8 zFKG_zX*SWgs4$-vE|tA=l`wyOnIHJ+@UAqV$>SE;y-M@QTi%#Qr0l)X6YsK#P6ZE_}=vSXV+ zkWgaKD{T+=fF7E-5K-A@!kAxZpz1R#6r=O`kaU~11`SGg5BRwV)-|g}N54M4mkx(_ zi&Ye9np~CvlwfZ({EH6Rw0K1xwr{)nF_wwI8FHet=C!#EAh@dY95s(lAruQMgS~e# zUpIb{9QCd+J4PB)N}wO6d8H2!jBJyELZF?tSz@SaD&IJHi@P2jcGwH*xNQC`%&?<5 z&mol7>>c!-1f(5fYHzx?3hqimD2fnav}q~KdPwZxWmo!80_@waTIuf_#mkL+Eq-F_ ze!OC7Yk}IBOx8ZdsIuwbT-LeY{Q2|s7T)fdIi%}TUSbBPwA#9ze(fW*4*?nTkEPrn zXqk#-&W($Y0=+!}88gil9mlEOo$t>gfCQBrta(L8II^ICCO5QpNvXg)37eneD2h?G z0cFSDyP^N1C>5`J9&JG1zLR;R<0>-nw9TA+VZ%((<2uY$IrqWjhQjxI*1lLf?P)|R zSm54^(Ny?~`bwNpiMhI(*A|kGm_Y<6=(qAGv3*m`y17Ae$|!B2p_)f0P{&DIMOOCq zd1v6Z^YRK$w0zB3CzIq3=G295DU~n*I_z~-(M|F0NUrd;#{J54C_6uKKW#(ReDa@{ zK(Ig<`_gJ-Xepav7{J{A?dD6xeSVJV)G=fpVA~4~CY?eAbD>4dFd6R&d&yM7=4<3p#5_KKWd>xN81z<|L$;J7)pbqPz9nZ+m;N-EVr*4vu^d+> zJb7hD#k5X$@Q+X^zE&Nb>Nt05ZN1(~GKdTw++^~pJhIKQ^Oj!dIXL@ydcytZ8%sI( zDTWryUhY)1wd|F0ix)|?T=YynPLc?Vo;d&QK6N4C>nv(%(Fk)p%culY8la9I5;thL z6}Tz2Glxuj&gG;}ZF1veDmoK6a8rKDf%9I%WWRATPXY&#n^o5X56%^d&SLc< zD^7lbFyN?)=RQORgf4$G;#4)Fp?<2BJuekWhOmIRjpflY4zqjGo^gO%qD5?&w zdC`gQ>%|gb0EaU5nB))0FSzr~_@7Qx(V7_Y@6y zIz;p6Rhba7=L}hT9)APLZ}>JjM67`muJ&)A8$@cZ5}NscAB=HujBUa^1Bnm(%g;ya zTHy9`nNJ==53y*G!KORl67r7sZ-4TwO8E4phYg47sbi+a?xvtMn>=5u?-pe)Q;673A+&DRv*F?=U-(b^tq+E|qgjM6_ z8|7lCT{KB%a^A3lS7^#bcQk**RhmC?3c~*FP3e>V9wMx4!TjP@(wq zEutjh-81G>O-S z>);Vi2AfO>2+jt-jZZM@>5R{{h-7ZK?v*Y(bSOerBzdYF4=#^FnqxT6V1%Mn9kh22 z0%amK9Ikog2uf@zhV$Jaq;s^4BnA9bZjCs%cl%s98Y=%!?YwTtuudtzv?7hQp^wfN z91*PnJ(Iovj9xVv>%Q%eU0<0+jmETEynhM7{Xk{dY`AQ{gKKn0mnRn&6)mQI4x6vg z`eK+0M7}AiXH^#1%5P=$(+qjr*fB6Nt}*Kid#_p0tc<0f#U!6+Flm>FE(nucI!*In zQx`b?prpFkR`a+r-vUXB`P0#-!h!+=i{36*7=wo~OzfgHXvX5Je9^;ckd@tXZ{=5c z)n`Yue1G+{`k-3fzge#1L?<0Pv+Asm$q+Q4@3Sh!JV={kp!rPdojY>sva;jEw^>e) z+HGy(5V|JQq7e5-xeun7;8#SN=Mm>psJ4T*n;nxk#QTun1IkoWRLO+_$buS@qvPp) zg@LhU4pEF~iiK}k4Z+(VYe;EIHScUN=?L}aQiI`K74{ zhyZ?aSSYH0@tv$=ivu_GpH6GL;iC9mW)3+pkC2;^lEMaJcChYt;#<#U`JRebO#`s| z7$8N0`}`$Sm*t+m=DwkMJ@Zf0A0{_kJ6k7J?3TV-=646K2kJC8_~8>5pyDHYylMj- z!fz^?oMmh?WxY)1bIIqKPiPCX-CKoY@+RgA#y7Mys@LeoS19tEp0(Fhk`s~tYF=4v zdv4c9f@N6G5rI*UJg3T8eLUVuNkHXJDL_n)*N)|rm&TB zNag6r=Z*A^hH9!9aV{ed+7DuXTA91}^uhHNcIT{=${Is53{eX!JLn^e!yXnf#^->b;Gaf-(<-)vGt|?3^P9 zpGRiDj7;cyO+6qMB8B);|4H&^uAauKWbenRDK~Ep*8D4imiF`Av9WLQM33Cq+xkw? zlo+pVqnnn%<;H~eB~$r9R=+#0_ieX7-;(`u%f7DVR`HV_TTN1@#ZxUn5O|%)$clHX zey!cVe=2irO|5rOd`(zA}7cF^YE4T_jfusO?UN+JCEhj)DdXs z1oL!_SBaet5^i24X~%dKs-y4Y{iBoY`y{yBxt&>ISX0o|`eKbNmF=IV5G$Ro?>Jxgp5-RH>SSw9R{&?0Uxn+h z{OqtgGin5B)CBKhw;l$PBV#}*HADFB9ZjD_*4XWa4zt+cw@gkV&2S9=YuJ*kn@7X@ z$G_EWT87!rZ=Fjanh-;^rYBcGWa0dZeaem--#$0!mwvHR%(}v6BQ`m&uare|ILl_h z4-~mX_wd*6-dT~Rk-F~$?ZpygIWGsMg*Xb3rOF9VjPQyg=gOKIPi66+y=WJVzA&Zf zYl~k3Auv;)rDS{Hmv9_@AWt?Ca+96cdq*O-P+uT~DYpUftv&!36h9@Y#Sn1I>D z%D?YB2V(03O*xTlBVG=H8a6D;-T!B?ED>@wyBZ9;+dqE)_=lz{Y8`n&#>6|u@*0Y)0d0m*e0#kQE1BYuiwVPc^gw(6~ z*OUAw`^C^n2#VXEn8|E{O-Txv(yp}f)$7|FynfgHbH9~-AKM&$Zu>EBx`gvk7oq(2 z9Fa91Yxo~FPMe-s!5Eq0Aw%vZJ|&S3my)@$XV0D;1@;mVFyJe!J|K7WMx^e~f6x^aw-5Qs_JfP6|GB_;A_f*%}x^cD@E z9U8t|$!C3=^4YtHOlTY3CcEf_eT4GM;YzF_Zd^3DO=Cn|hxYkl>$PR4AI<~WrWdwH z*>QL$R3P9A&W!HrwQHRL`^c+&3(yG1l1sd3q(Aj(AF<(0)EsZtjCAiJRzHl?t-V9$ z16PxfkH^h*Drqiny^oN-IR{D7N%$xiCbnBa_hGWh&j(?W!$Tx)n0>8IBwha?M~`JV zCqO$X2wQ<1DR2I(wba=8kt@~$JY0Ye!W13?gDTEz6CYMVT|>Ic19Z!bZ~gdz8;lbc z;=o30*>0xD5W5zlFiWyTXVf2G;2^{!Gr{>bd`Xvz$%pm;Z$q{Wh%JmQz>K zTzOR9d2G_fi}-_&j$2>wps?YcJ|wM7_g0BwVk%bq%8u7Wgg|eH7?)`0^CIbDe%$CG zz>%s(Z&UWS`b*E~WrcEJn(+O&Qy{Jbd2RLf0w~Q{A6rR%tCG0Q*O)`+pqe8*nelOO zqx#`MLjg>l7z_!@(V&H|#q;FCK7zix$~im+~n0XbQacRmBXdK=nw z((h9uk#HHAaAn@9m)OEu!Z}iEUCRIqO7`=^GyG+2!wNQS*szEbTxeSj)(p*&;}FIAWk}ig5W%qnPjqb7 zaDfGO!SP?~SyAex+xzW9=hz6MJC$)L6|~qm6r5Ik+MXT&{|e;AZ_DK*Mi{5e{NY=# zQ9%8mA~zq+DdIfSyFfE!zPpIxO)%0jsD_c2;oEto4-j+7>*tlu8t{XVNnL6xw0(C! z>oQVS#zwYtl}{Bt6VUSg+vf)o-G%f2;M$C-7g`v4jYFHmj(#GzxC<452k^ zhF)dR#;BNHKzhcLPu&u$-aQUpb^0t%VFyzV3H5NE!FKU;!=!lRwHfhe<+M(gw*kdI zDs`bp5&Mv{O_dH?i030YGa_?{DqR})FAiai0+6&4L|XG&AG~`WNMOIAb&5#2SV)AJ zR+2t0;5$V3Lb1Jk$MH&I_-{O-HaI7F;A{`Dw3jI`Iumc!u(25-02;;lmHoh)|iO071s zj7p&V>x}*CD928TxM0ro%w%(+|I7Jk2#;`Z9HJR+_QTeR;u4x?;xJh$NRn_cVR3YL z{{bVr8642!cd!x&QI0!63lfbPXOHX_Uwk3j3a`31H}0cGj9YiVkIV(d!P_h=U;!(P zJ>*1?%oSUt%g43}O5(Q|^_>*TvPJ9aX1X8vnXPCszjQIQc|VElNepO2_r5IW zH#U&}1*c{W-MX}|I`Dr&JRY4}2C0A-UZ-AqHZgFQHFfNH%umDOZNFPun7#(;9a3oqwnbmU0iB*@7-b%X9 zv%%F!ts}33IT1JkqB`9mmOIUc&Evz#?93kkZ|D`IvR@a-Hb6Y!RIu5VT5Z zs^#;Tq?4#v3VJ(O^O`o~!n42p)!gLJm;(0zJkzggarT>Im*0#cnfye+SX-CQo2keI zIm1_U-}dGmJOufNdMu}oG4~=4RN@0C6M;>+dfxQTg2D=Aen{pbnG%ZEz2t0wty`sk zYD{y6>Yl?dhXyVCQYV6ZQMe>hVjQje36}XeY#UU$zDBr~WthLLg=)?#=idA~>rDd} zdbd|!=h1OOw$S5M!oCy}QGc*})1Nh1e6eLYK&RMf(K+0*s*d5T<|&B(R`lsspIUj9 zP65(`EYkTc^&#V!TM}O2$)3*{>C#@Ee4or+7FQ5P=`N;JVMnLbdOxjV2^c1SI0+L} z?6f5MW}5e^Kj=g+QOCyC^Q_oN0fgZzr>8lQ&;SZ-4yJ@YXuiRAOC%9<szZ*Vn^$xOPs7E_-ke{cB83sFngo$5qj z#BQdYybmUE!h|!2csp{`mAELqee-6B=8XvXI5dh#UKn@3TVfqTCZXjpY42A*X(zT_ zwomu9{qoDZ;8gUvKS_|?otyB=+C3$E_a@U{ba%GBK4 z+}@8WAkA2lRtR-GafcoBBgx_Ejs%XF3@@!42ZIU4;_BKpL(}UsrMuv3xxvrRnOG`a zFlsaZqL*acfS~e!T~f{Vv(OK-xE8~ooD{yG3g(v(DyYBa4j8|<3szY|_(3kwDb);h zzBq*tI;w7N{_cwi_*%H)Rk5)nc}}0TUBrvD)%9vGjkPT#qPORmZx|#Nql#&B3ClLv zezVgKwgow^rcHrbrxtJk#;tU*$nE8zyX$UpBpTVOI9;A<6?#JCL`iGoMAa4jB;&@7 z$vzwjCsv=G>kBUpLx$|k{F4)aq}CoV86uH5=1Uv>`p|$rY|bn-J0NS$qu+9>D?@pY z-J*&-wo|GiTbJH_(uWtpZ3=Yp&%mSh(Er8TnFr*Ye{a7Tvl?T*V_&kbA;wZs!i-(Y zlFHIz52a|crehjT5~m=8z=VFWa`Yn-Zmdn7TYd*gJ^* z46EronMW)Lt{f7Wxed;A^j=LcW(Nqxl%}LxkS_>$uFQ$B6EHu^nXex+RGfMs9?5Tt zd7;0HB}_)K`Vzzr&D0;~>xGBN6w00vSI;qBa1Nm<|E-nMQ)d~7va3a<){zZ8;Mz#| z5>+E4yXWiKLmlLFE`Bwwq`V0~n=)j#HYOJ+!o$gxuth>?AFX9-)nB)NZvtm~gtkc( z?C)GxLkD?liO;5s$XE&&f^TCdij%Z>8s46BfA|c#=t!6tG#(d+Oa^WRrWLDouHI}T zNkuu~2?@XLK@&C+XJnR0Y*${qcrjmZgn1$Bu7>W{*i}_fAYNeV6}fgjYgbfdIjF~t zpAJP(+{M6A$8EQtqbh`#@7cTe0lBH91XB71mBhWg^hl}>x|?ee%reqH8WLadODrcD zM^Rc(i45?`HkIekX53U3Ywp8RlaEGpOlAg%l1BPwbO;4UUW*WBjQNUYH{?Y3%d%!1 zFbJ9wKf!S-nwb6Eg48MpK-D8oNi>_)m`8Zbf^<}Tu`wz#QVbSp8IOIbv{g9yj_oE4 z2eBp>DPQBl&pK1F!+y~~GKNEP)cfM^%?Qq)zr-R`Xd{%Or(cB~8gpcu_g zXUF)}K@HWfuZY6s^Oa`vU5K|f>{p5VGGPO=HPA=Qx)?PvWALEq_0hyUkA01U&A|y} z6+G~&5n*Q(5ktylt5K@zY2IeiS;Q3e*1NT!FYWgH{=4XkifWK(TF*W{s%=g-i4xH_ zOU8C6@k?V%e2uBdCcl0A3ClWB2&z1gHM{~c!;MK-IiVy?6{;cMBl4LP3LBCNd{4@( zH>8doO?B4NCgZU>JS-p=z6TW{oOdF;CjSY2{>l_V0I~I8;Yn3B1n99({nl;U%HabD zcTk3T^l4(oIDOPT3l5bzm>A-3KLq*dk)TNWvBedY@RqRyju36PKK}rt*zl5X&Q<`! zYO^LoaCieYA`>T4h+%V6`FBE~(%n6wCTofKu-tg-vI#Vak2=QrCYc=bo4=HJvCXK8$dF#OT>QFq^Gz(E&duRmNKA)!(9KgCWczMxPOF zb^>&XBq4|JJz-$bk)iWq!#T_Js)yk=(3KFr5ANqYAwg281sY9z)dARjJ|81AZ_Xf% z2@@=7@w+q{zNeH4CIHCKALyyCzWBtiB@mypFyxf{c12e?jR_wMEjoYu?J1adVk+sj zIl&qtoh)$5i`!*ADL1nHnj5UI&aGaWKE8S{UeVVZ3a3w<+I6-94I3OLIUuvS8AHRb zoaQdY_HeGOf5dH9W@PPS@F$aTJITG^iWt-0^Er{*$dMDm5TwvK6#YMpx4CM0lEPT1 zo&djC#8MSwC43N)7|zSrrO$|gk)gsGGy)Asgwmk;7NW*m1{{G4h&w|4q0LQ)Nu=3GH50ZqUBAa-Ke%c_N$XalW&;4 z%RQgbEWu=X`l)N$E(y@YdTtF5d)?nBuN(SNg-sO?xDG5W-a^ONIozA%#q!8+rhS!- zO8j~mxm9)v-G63kIf&y_=w_B%iaJ6kDFtRjglsO$5sLpgZW5&U92d)oiZSa$Y184V8gS~6G>e7 zBDBI50RzqFiy*V`Sci`o5zcS&9m36}8vFT4znjky5aH#7u%kTABBRYxpUrRGXW2** z>mL}H0!YenC}!b&M;HUc)c-zDT3<+^X@GNrt#dxrV)WAMKY8z{gw!stpTp4vmK~>Y zY~R}BD3AY`nnsqo7#2mQI@>vuuk6^f=Z*|(L6=F*g)7eAcZ^M^ogz}w0e1`Cn`a65 z@z_a58A!=FAlB(7YYyNvBa>Ru?R$-C8aw=_iRkgnvm3W6`S3ixH4`W5jCpxkz|5_E zX3Hh9RGOby!xbR=KFZYdp=SRaGJy^}Jr3(&p0cz&C}Pf#t&i2Lg~WJ~u|GL`q0uh@ z2LM^=&(IgpsKSSiE`?;UVq=)lOn+9$mUse=?UV45^&@h3oTi3e3znX#b_*v4?aPcA z5o#Xmp~NKB@(xqtvM@LQao{7g0!Rwyk^oYV`M-PrzF6nbnK$q1>O@GKH5qW`a;aKU_@;v7F0K@uiHz_)M}a7~ zu?I_OYN6$|0DrcQJ$m?9`C(3RX`9=UZhNsK4L}*LQ)~do&inep<}jXOr&xiAnM;XX zJ1LuI=}dyr_dxW*Ved;OHZS5Hol%<*XhlloH*enjkj?pH!;9Hk78JO9d`U}mP*a=X zAw*$PuFmg{lW>X3C773hYLW^{UI5d|`|W-PY^Asn{x~^j+dJ`?{mCRNMZCtvnf257 zK52AFPD?8EXuB^*U=U}JwvOJCVV*zixj*1rfR7U=PV~)|8iz!~v=*{!20A=vDH=MG z3dT9p4_}e1;`gbIoktLcSYq+mns$qvz-J4~pRy}5WkJ@InRvggWD|=LlK#8LKezQ> zbJYcpFklqI@QPEh_QFmVC_rRkELuf|8cy;QfgB@Zj%hZ+z*FFZYBPVGH}8#rVm!y%G8TvVy4sa^XAa?xqJ z^}-{SZ^OQ9x#O9r{zOniaqoa6ZyUC!b1}e-F{@mJ%Y_&hFwJ^qoI}xWL3*T$pm8B) zD4{(Oc@33QDfSkDhl!LJjh}BWzX#ZPK5Xc@h%Mv|rU$wssctx~f6n24_d;N^P-_cXJ;4T*zu<4>XOfUZoXFBNapTDA+yB7I_0mH#pCXX3NMr7 z(^!Tkzn4Ab8jT4=k`*oQ(agRJAt2_G*+h&c-mXQ>RSq!TqVe;^Y&wK5!XyPOu12fi zhH19#!wcUNtm+aIAu)(|NXTuv6B)(oGQ`8WH-7HL`Mid2(5@;iRU)jE^M(v#s7KC$ z<0<)tQ-CHG z5f2Xj&H$lwn;-oK4H&Q}|3<-TYsx9$vioIm-OHa|dXVI;WwY9>?eFieB(z+!`?r+a zGFTECrpJs=PuDl|V6o^9m`=GRuEPu)-wUWAo?jAA)aE+B=6j+A!x^qOAbX#!Mp+iT zgi2Sym8H(9#EA3Si&foOwY=LLG;bK}SYlBtCwyoybsrq-pJVg}B>KU2`VDha~BAWFSR~?6z*to+^ll zZ0H>s={_c%V~UP(9mezLpYZ5M*DFjtChmTnwu-#iQ#v=2!%I4Hb^8`Mz@RzH*qq_l zk!>G&g;x|F@VEv=th=ms>G5>c-!(hRJ!<>a@7%>MTe*uduL*x+vRJ{v= z$Kuu_R}22=!8lfsjNEV8%c)0Jy4?6`_niAHw?C`x&d!X8D_E4826ySn zMU&d}t-VZ`75=|8z{=`od26Pokt7VsDj+|&lv{D)(4l#u35q@w?{+_E)le~aOECn_px{nqZMxfMjn!Z5Q>di1FN8$VLH*_;tll_3dtwQsF- zE1;#G6aD=ARazTu3cuzwPRY*BUU#c`wZYF$jH)(G)C-x=ivZU*pjETzz6bKC)`Q+B0>y|;z*%Q49gd$T5Eg&9%YAvhXj(xUY4|$E9KXmu;FD#sW z86|w^8n?@K{F+&k&2i0X8ge@wpCMq5grIwMaTG0U;GO_6N5^Gz?b(`7h0I*Vrq0g?4R1&5m^P7I{cy$pqkTi-vZ$|p*ezAUY zL$zr(9#iYUAWgHuCi}Q~MQLjN^HjUbBZs_wS{tobNBB=j|IR~2pyM%kL$i3epm#zT zS@`SML*hPh1`qkzYI1D1~h|P&b z5B-m|9EYog=^`dyspTN@HqM6Y)|DAflwi&j*xH(hL_ACQ76HS=&mdDv;t=3Q|2qw8 z2fN!4S4;MI-z>OQzU24;<$x!x@m@@cnSo;1TK3}QVrxEFQ}g`aTGSM?=FBlWK5A9d zt0?yG!u&*#Bq@wjx6u9CbGNVfCTmxjEe59z4g2RG4a=Z()?b+yQ|g^;>@Q#9!Iu-) zQkxsePsF#X>ulA*aPm787cbhlvQKpO{z_pGvrrluzC>=KHP4*&Vj(&()1 zO!cy9QUbg#Yev6(+H=ziBkR@wR98n0O*yZkeyXWG45eOl>D9hnyXDc>-z5&}hyCQE zF@~Fg_gzXN2BPcyWSi###nQYqJ8tlAn#LNo8khfU zzi+m^bi^+CUP)(K3|+E`W{_M@uYQhQi& zYwc%gD=5F0T8sJ%+XCR3f~!_KI7hGER* znpbzdue<8(+E%4Yw{HQ2!;o^@wdPIsZxx>zGhjNz7gEfMf=BI@{plsbT+6|B|J=Ecoy&c9{E5s0xj{ta+Q7r(d&Qd1gq-Vf9O6^a_tX zGc8duSZMB>ezN-tqdE!(l-dEDgvpTYV_?mdjGH$jW}l_{R@w}bEzoS(gNER>Y1*&JMfj%=scCPpV7kTtH)9%mc0XZrN%&3$%h zLH4>Z^k)U;?NW3(^S`&ZBIV*w)F!Sp;?Z#!@#~b;=;?%zvrRPuN-&f$B(Yobu9cyM z^I@kx1e%z(OUkM_#6&mjgx3wvfzRx(*4+xzA3b_>4FPwtg|&TDH4|Sidro^z%9L}2 zXP~UAc&e`oslY?SVzC0x|8@2zKEoiy26%6Mwsd16;{vGLU0x~yVIq(faVZh=0i+zp zYV8`mPLli#8gzTi+k@N-l-_)m%ha%lg~QerUQI~dwpF)fk3M#Q(F63l z+&N!Fp~#PPE3Obo+dfLU$Evt>>(&r0t;~Q%90t}o|MV*fA+LBv#*Pmc#MG$*<0^~e zUN!J<=u>hR(!QNtWv?AHwiI4P`fuUtMU<*j#wb7^=TE;@NITpb)B5?R-k`iSWXVkba6_ULA+t7Rzgd z?J^a&Q%XW?noTKNAK-WMufLMg)5DXV2?}o9zM4s|i+GMNFS>l>ziU^t?qEvYDvN7w zD4AFWuMn`UBDa}(wHtZ?Bhxz2MRFH3h_m9$ zWF-TF)#gUd6|**9z~?X94-}$FaRLX9pi)98Vclh8jZR?{@p1NB>N`T<)XJkgwTaDj z=~J)Xc_rg{lx3l-Z601YMXj+tBb}-?6DDNH?L?=c;A=Ph zwN!(NWKf@uS)aRrg#5zVmtf9rJxdCrqwz2K^+ zR$Aat?!D8&mhw@$bHF*(gQi%c@Ir5g9BTo1yfeW)q<(zUyJx@dcuSeHL?2kVdL4W_ z;WjdRjCI3w4<|$`Bt{8V$gHG|vmg6xet6@E2};XIHps7ar=Ebh-3-@t~kZOd{E;BS9(h@>^ zgD5wr8tD$7MS?K78!7iL?eWa*y7X|AW40Rv8=Cm0OP8`lmkjG9&yakvSfAbqA9fX3 zO~l#b_s4epP$nJKrV9^V$b8OGYVaUNlJnQrZ^NRNNV7axF*D59=RE~t^QQrJZ2u`B z?qL4U=utBFAgjEf(Kxcl%~j(;9kxC{ zH2p+mEvVvO4^$H3Y9fkr5oi!3E_=;>MkXKtKG2w~OFCM^8^|3Wv4t@@!Lx5T()ucA zt%{bd`qh|0N=83dIULT!95IX%{2*hV`A+~~607O67M`k~Kx!?j%9Cq9Kp5eyYSE05 zC3{a@mN|J?$1Yu7T+j#B494Dz*L1UZg$VB;2|@!zZ`dyCHz9@UgE_T*l=0t+Gv+ah zjI7}@xirY0H)>_D!7l!#_O7b(VI;qcG0m1SNN-TO=xlbROlE%W-vU%@-VH7;=#)6m z6_u2fylSkXHRdPdPI$}vPmu?!Z- z1bD-uCvN|J@$^J(z?=vuYj*ztu=z^)ZAsbw|a_O>2iwIGI84+G* zzgi9z#_7deQ7VyFFKtLB`G!#5U8|vyF2Pn*j={bJqi%a=<51-`aW@u~E+`i?|3oW_ zLQa~0oG$P&#k8AXWj7@1@qLZS zCEG*PfmeY&H{)|s0j{5@fBv=GLPU^^z`i;qKf$`w8bRz zVE+RlHSMyVK*j82dvr7{^}vdFGM~VTp)+|(WshRQ-{VU@lhQSt+*T}yRDCuo4(YQo z>m-txfz5igM4YkC&8MZxPTcIuPYNzoT3Q-M>cgfFBveW5S^T98pL0~Y%4x;P_TSDQ z%A^Ar+7XFLBysp*V^Awd)O)m!w9~;#Fe%)m=Y;KJc_IXTd1xFD^2yY~LDE-bZbLED z^!5vQk|s!D@n}Ft<%$u$;(ny*L?J*nEhjUDg7^N4x5RyEFo-xAj200m!?ci~BC=Ps ze2!=?F(rvfcPpri3RsGISKGoUL3;3(zEPVXb(ySIF)lTDpD@rB02t(SWOMvucXu^u z?y?k2cm2{q=LWkQ?dY-6ScO_7f=wcuNk$GW8d3IHBE5ej-BpCzL~2XscKS=p7_3o{ z1CC<3SNUN;(YC#II)tc}*A1D@gsL4YVrdL!lJ`H{9!kuFe0G*5yZ)!_JQSyDW9Z2&;rDN&Eod-DO1joh>9Mi z=wY*t^fZD>#yLEn(KI>{*SvVYI>0_rPF5K}Y$?{C!ZPE7ZFP4V{6_DPSi)-w!${5$@O{@?!ertg4*mg&`he!g1$hU~q;4h*a2PaeKN}Dr zynT6a;x3GR4XLgfn7)grrPYz{zhM19lST%6oym@ex%o-kg`STTLIqn&dVsXOAWxnT z?)P$Gxh^xUf~>FQG(bYZ#xrpa4&B0vqfJ^atP^@w!cQ04(KMZ>%J=MkHy3wqA%n1E zSCvh2Zx#byX?n4iOd^)wd;oV$IM8gem=)*CFT1?-64!45v4O7BUf0fqm;q%3*hh=Z zNtBA=Jk966x?WA##|{-!by?7)44#SUz^-p-=af;`Ux_U)TmBxW zFF+uJQ=)CpC(j9-<9F#%t;=c0~sAki*zY5F18%hl11RUyX^X zhsu>nH(sCgj5BrXC9yn*q#YTYw3g z@41Ey%IuzUQ^KVJ|Fp|)w(!ZPJ`vG38TX+(a5w8@b#q)qG41M87(5)(Y6d4O>MVJ` zWOQI0kyB2{nX-@e(FpVBtRU_-YqmWA(N*rcr zW;PelHo+6UDUJ&hM)ufqWwfc=CiosfN{?B~QaJH{6WrzC+htkBzmXsucwSnn0S#l9 zWl9-KOoF;#-Le%Xhl^V~P1~$1P2>7QET1kzHxxDMTQ$)ZV}*1$mab4Q!xkdU86lJl7D_rs=6&`<52c zaC`eIZ!0m)_(SnbV!lQ88YXtI3}#?3nzfOtkTcj+h_DD)?{j{BmS)onxM$ReC-3cN zd7k4MBqHD2f1({(s#LQO@tAZNowbn!NZ2p{9d|m~CqiNM5~3V%&JZLg7TzW+p}#oP z_4S=WSAP7~y}weOZ{?p%%!;I5itSmDAYoT4M*TYMqLCjc{85F)?C!H}+ERp{DUKreW`*%ajKj_kg1h7l}Od!X!YDXgiY_iB~#32zg2}{9+iW=57HbGMh8@ zY8|AU^-9;CgZtrI(Q=+IvNhPGH>)~3a@@7h&^!e5~v zvwN~gju12!aCXUWh1Rhe7PFv7>u!FzMkI1s$};#NmqE%*#%)Am`JL|Oudj7eDzP2x zLqmQ(+j_6B-8MaGVbNsR=R4-*V48p?z>tzxW@t1oM^I0i%aqxpupj}F6lqq4z+6!$ zjv7-wLRt;zu!--w&@Dvu%jfk7ZEr(f$vv5f4OH%h-kl}gkY_7W0~#@)BR`AI7Y`|w zXFJ<>{2~aqo_g1HnNtBmh`!tVLIOP*Q(3oA6$}IOmUHBanqQbzB0e*fav)Kl$?+aw zY!?AXm*<6?ZI)AOwV#zYy1^{5xhY;S&eBwG4dZ|{Nc!V9wob}vpi}=jU%n#~)Pk;$ z57-eDIAI)>12~QLaabZ`2aa*c;4zYsG*K!cvMRpJhUS|L?X*}d_wIaoM`IjQQ^TsW z65rUuj-g6(i8VG;ja@E&d+-P}8Y1E1;!iBM#j;HKVJGq2)QE|Oa?cvR<)78S!Firg zrXho*$?GY01Wb`>WK{J2qg=arGKD8|@+_VVjZIoE6$=$ZMLR9jiP~!?_)CCTSX}e_ z+gGG+HVkYAjIIT^s##s&%QKRx{g9?`D8%`20i$Z2OMt z>bCrfc)epH)B`Fgzmi$s4$I^NCapD)laL#6gBQO%>(|F zoCvAk_t#O#Uv7KrRA#)BURp3M$m3~A7oqg&*4)Wu6csln-_pIax6wEEsVrHe7I{E< zD3;2Rs+i#JXMn@u2!TBVMn|sX=8W_%cZ2;(J?@p3)Q+Q-v9Yl)AD2n{lq840^q}Nd z`lNZPb1R%%+6wI^_IT%Z@O@FJ(Jgws4(_|KK1C}=W6^7phBcgY@+ItGnIHB!G=ThS zQ}WL<&bno7N4{COrn&<%gfu8Ktb7rJ0m3>HQ7vP&=e<>nzc1M(&u>$5>Vxhqf>Niz z`4D(Y8O9swb4ZhX$0d1A28r&YpgHQXY)JvZGP)k`#G%8se#An--A%`6(-4JIrJ7o( z4Q^>75L>!;*d8Dt&U0Q!F7TnC+8JGQMuDXZUd^LL8g2%+G1}*l*lpg2d&bMo1-Gku zU_UdT^GUHPTTN=KRMfquDe7}vtan&`5d$gm=>|eUe)qUli=z65kBD3uI6rWak4UE#tF`O#*D!~MiX_LhOXi^d zB75L-NCVF3d%!C`=8-s|0P1IS?Rtq^RWay1I&wY1P+?H`QSV(!XNJdCBy}=l<&}0} z4`G_*or^!nYZO&z0_Sp$8BY77rF+{5Fp{9oFIehONI|6)N6}2qCs9Jm=(?$BUAij| zMrxVTz45rX?*o)rg#jlymf}S1H`#oB>Ikc=d8L-As?KEjtLkdUpdS-BmDef*`4F;I zkTW^Ms43wTjrQ4SCiw$V8?Vw-a#X4Fp6}36Mht|(Bu3gM;Jx+2Rn)rG9H3urqAe# z#Nnt%6UawnV=4zHAO?3C`99q)^3O7x37VRz{ze&Bua~9waCcR$wE&ua8njYyf<@sa z$?4A?v1RUOGyS#!v+QWBG2`ksvA$B=*8HfYB6BrRqBfjTq508v*&3S;(0DZ#fkNUr};HQQS0OvMN6G5akh1yXUK$F zj=ECbS-@yJ8)(&tcZ=OE?#Kb{E0avJ8wl`DG# z3puu}+qTvAITQxAkNHx*)6U6EO{}5S#C4@4ci*40s_5$!a4L3V3jf%Dz<_Z#%lmxv z_%nQ!7*i>1aM-j-$4AjYh;6s+$m$AI^y&0zaW4b~@z^Qd`)GNp_p+i(6%Iw)TD|&4 zxol^6fKCo`ZEf#@isx;oAL7_H=Ol=CZM!0~WcW-d`$fs=77mT|%I1T+E8uK7JV(B*nBZ|Zbv2=QxLILJR9U0s#BR-v zG5EqqlsQf$C6`pkdhFhRTdbdTGjefRWt#sSi;qJG4@!CLm9}AU$b@ZKEs{{ZOnZKI zX%{4I6zs9=jC&X|=0>JpJwZXU?xWAl+$uQ$97}Z`e!TRdfU?(`OOpLgA3aOCblhl0 z=k)x&kawXkrSeu5)#OAKD-j0w<_$HC7=UkE9riB+@UeO=Pr)EDZr1w}6L-_U9CdEs zr@b<>eVc_X%3yBKcQ*D9Dd@+0To|3=8@l;`scmFT-i8B{>{M0;Gn=o>N;|@1BjVk0 zx}|B}7`wzdYOE;U{&eGiTt?E@yqX@9R)*i-n|Go3{5#yp)lT$K7K#053av}*(f=59 z;B`s5z3#J4VNH2obr`%}p5rrFkVn0}y-W8lx%?i>^~TtEe4ICo<lLMlz--a8_7o~W2RS3Iaa9W+iWWHyK~$_&80KehuYIJ zNMT+cz~?Js#G&Cs!@u0N;lsmg6l{$jxRqSL%G56E^N_F*1g_74q01f*E-#~Gi>TE2kTz^?mP$dDLKa2%$WDNNI0N6PogXncFS5bnm))%CRg z!Y!%F&3(>x4zIn(^j!1!@gHDezW@}Pxj!B@jrvEX8Bb#L@Rr}*^vEungyij!@wLWtqgQ;zt> z4=wEO?zXVDaGeA^hdf%u_|f+fSrUV9qT|XO?djB({-P(u2hdSX4zi;Y-F6^JK|uGAHqIXPV{?Ju8|bOZ#!MAmgz7sx*24E%->r604ZLh`((dB zlrUa7rsEluT>O5$ufPD<-cP#F2eudi{xFEp={`q zHFt+m2ht06<;ujTORibrc^)=ME(qTrOFmyfS1u_#+Q!VF*Ob098L)|Z^JUVad6zdV zL#U;|=~j=9!_5~8%AowNm^WS{z=#dnzN>%#-pfesU@P!dx zq(P!`m=QAPm$~Et7A9flIT}WWm5g+>mauap?8Gu-QljnBkRbJvnWHy|7#~*%9>XX4 zqwL<9UHFVUc>3ua29X8p&?W<+z>73K;Z?EDX2Klb@zNV18F7d%@f~3`0gB7bf8|1M zr3?-3?x2CmxE-x&+tNAr3cp)=O--Bz#F-q(S%+MbBFfLW&F9ts}|C2`#s zmA8j@t*kv<4)KZt-$#8m>8&bv>%`onbRwh-0pnS4Zf+l*h&Ge(fn^Ck^gZ=5H8Qc^ zG#cH(ZP1+Y|NN5i_1yzqM*e_&ywuT@jV@k&9A!OV?VHW=KF265*MAA0iLv<_nIZ<} zb8tPA~c!A#Tl0Ml~95nmxH3-u)({*l};Ex!eEHpP`yZs_Jq{N|C9%t-OF-6(W0RQ!}i-SgV7_wl7xIY_*sq}sq8uQS^U?`*O zB5j(Gia0sVQ=m5pR$vkXMEyL&r)>LV#=vD#R(_D;j%EjD@B#RqPNJDYX1Yr5C8O8! zR=pjSN5H`_P888nFp7=Pp=T?7Yo}?R9R|;x#EjT{{=qPCWiD<|%4ME33V1(o{E_zv z42dlTkAKmrAAS|18%E5NoMr_WLZ*;~(^5oUKE^bF_}%cp&vS~GPpnhwdWHp>Ef@N=$Oa6Rm)-$M|?GgB9@`;%bi%+nYDo6Dv2w8SlK!G#A`DmcHl5pSBi@ z71Q6_4VWFcav_@nN4wuweq{dW-j<`7!L$L>#N{)ZK#e5meSv&ngZKAVq+aHz-@lmjN2R;E>_i=`7O= z{Lae`_E(2lUI8PmkrCEpr&*U|+FDwWKYsk_?7+NC!fZIJOt@YnO*Mi|CLQb} zrzUg+Ff|u||H4$z454Onb}Yb@J1}Sa-E(4;%A1aSGjNsf_$|^yJ4Y#p={8R18$I1> z%!$cjA>DDT$)(j8p0n+OnJ!!Se&ySC17^wSh!7lr{lEJ~xyrh} z22)x5_Lma6so->miii{VFrO7A9ztGeGt9V@iyNO!3*35S`80Vp^a?IL|1w=taxs;b z89}Gr`LVIFr4Voa26VF59bmx{7sB7}&Liwv%uZsb@OEpMc@IKu0;GB#@kImy_|Z{5 z%QwOO!AZD$*tS0%A?~o2e=ztaaJ7l_8UgSP!s-YpX3_-Ui~%XmhiV$!T%Xu%%z`F8f6_)K)H2vP3rLtVG<2lTOY@jX{vY{Wjkzb?W#Cw+VDk;9WJ9@z3EM9-iE3S% zaUd#$V+AQuA2}!ZRq35_|HH?<*gl!I$i>yE-3%JJ2~kiQE&q1z4$z^{^fa+UXg~Q zF}qZEs`RQwZ-wA>-004?KY#v>DVRy}(Gm6+pWX68FlsJ&*S3oe7En3$;2iJ=2CK@O?Kkt=`y=jZ==M*%{bJ+1`**Y@I4COUOG?JgcXM^3REhswE z&))~+0Tu|gL2&Su4;y4zgLVzrneTYN)i>i$ism7_@#x*1G`*f>kKmAjqBoH4H0Su( z+u`3_^KwXEpa>%ISc@OyQ-WpA7WIsA4ykO=g{T7l4Nxy4d+MC28$^DLeVE>SiRN|LnB<5xCb1g2oeJ)i-gzNlIP2>kbspr7a_jBVIz zD*3xGYK8JdXz}c~s~6RMIgAo9BFaQqbC=aO?vBdrY|b4ygq>ij1#Wok?CUp!nt-l#K`fg4*^@o|LOH4 zidIP(DN;it;M=aOwK+V*WEPXw87xYF^o4_G9O~n9J89h^7Lk_x03YEcsfNdoi9?r$rTp-kz|7 zW&V``&)fTNhCP8toyl0XbMf7-IClO(Kl^Uh&fhRWNuqkIgPwl3phj!FWFh&8I?7qK};o)`#9lwe3rP znDq-sTzT~q<77@v4O?94;l#e3@e5N>`J^uX424ti?R*pfvl z3yB7@WADkjqIyT;uCmz69xbLzSyzbLLRtliF6G-qPEIroaw8sUjVH!Ki}gg!B#_cS zh;w_$+>?_$6RkjrGGSne02PjdnHb6lXv8gS(2A|@aamr7dNImwnRFespudXy2UJQ4 zMZl{vg3OU1wyX4#1#9E@4kK?!x^ZK&cvVB@klYLEqW9H~f%Aa@z(i)t%O^>vix~=6 za4`1ua4N(I0i5$NSSHU2m+R$27DYZ!EwLmm65^4irDcHq&7kM$epV_CDCu$Y`+_Bl zCJoU+EZBjP>2b|`Z=G?Gg*C(??jmOmWtBg9<2fd0Pb@e%lQT^Ou0;@!0R(pN_tK9A zzqJpmd-(97+j&VHYCW}3Em-^BR2hccZHEG;$k=Pb+guyabZO+vU_S(-DO^7ve;8H~ zkowS4zl2$vHf<7NVgebdz!727Fl@T3)pw(U2w|e@CK^vDo=SSbfuiE3ibWT@dZ?C` z2|Qsri4oBIND=}*=swBqc?f}Jho<)&={nSCK!TRrv$=HozVN~93F0ITFwZyX8`f{N zLGvS5-Tly)JY_RFl9^0gidq&F08Z`t)RG;<1ULO!=30`;p`qp_)a4=TBnn*`_61jV zR(<9^S{dZrZ^(9)sa9cc20kAm+SA`YO#~D3h|)( z*Pm-^zk5Ls?!>oBfcIbjqBiVz<#GP+zx)4d8Q#0I(Bc(Gx;Pvdsb-(m9(SsntpZ*F z*9r)BNbVGSH><;&xu6mlP;-B?F>>_3uU-DzK#98k`>)*xjBFWZ`S91*Jd`Ethri%I zmd^a(V*TN-FQ~{uAO3>>SlI4=@hCm9R&hyjTOF0t<}klw{C|C2r2DBQmq#g$#{VdY z`h4lv|L^a1kAM8v9JM6Zu_oOwe5OoB|DTto;aA%7Gn3NTAbtco$7(-_nQatEoc!a# z2_HVJcjwoyUdak;skqhDDhGcU(r)=AWPZN+UpxYS{QvK2`fz>!t2ZQnIQcou@}#sJLg?r&6MLjpcHTbVbLEqa$E$@itCbT4TQFc zE8D8myZY+6SEw2gioPLnMG7teM9JF7TWAw$7Xa2GQ z!tX_#c16@%dVQP{&=H_$oTv;RLFx(sVd6ilqtLyi0tK*exSXpcd=xsRlGq6;kl_Um zzH#|hO8|+tY~6|?lL53uVV5mHN$0bZQN0;ufIM%ubm`+%EhaMlc@$O$5w;+*n4Yfh zjZEb`b$>Pm!(Zr^q(ZLZ-}t@jX@f!DfrLP zZp}LgFETJovpl{iyEx^j@O0%5MQj*#sf_bTujNOFl1&h@nB~Lyb-kR9Otdgdk=+)!8GRU41Y=`wE_I(-3TUubdhw*W;)^G*m%94_V|omr{Qgm%o=iwn2|Y(k znu{K)t*w0{u@<;T4o+d@qXIdUvq#EaA&b(UuAW&RKbDVW4t0c^n_DA@XV)fID8Ql` z`nIYtoL=`V57D@V*d~lSIakkkSo{Dnox^z$ze9TW^pbaA={N>1L=uYhV-1=F$>dsu zYmvw2dqUVFZHIlXW^r6iCDOo@k|kYc$CS{g!;ie|=2 z1iw6>)&>dO&6{oWW5HDV{;;?E$Q4*6HeK5FT|+nxaWsV%z!fDy4G~HX0M1;=1hMqt zuZ9}!xVneou@r`zphfUl7VLB@6cPl}6)#bKhB1HSO}bYV6Bu!us?xj?_drsUI#EI4 zhqT(J`6{?l!E0+nNjV^_JFB!C7mQNtd`LCnRdV$3VbO{Hp8QOPmOv7Y27O=hMn=I= ztb_j%Fze{>^FoJ*=YG*R!e+*d8DcyZLP1c*us}9DdbHMy6NvIGfVUuW*OFzCQ|pWD zgwf5gXCyR&HOs^bXv%lSKc~ewlIjYAptbJ7-{N~nFNZ0GoKsoOaG-6NLy(Bg3^${k(Z`g-6*W5pXi|ml(eBxI)=Q5v1 z6gUTjF6Il9!`j0bX2Kzv>>7&RGGmkb{NDq-AcWaYBEEoyKRGOla3?HemQFT%SKZ-j zYsjxsYZL0fXR>I+!_Rw>YQI5h;0_anX~l?h}q2=L7++&tY7l{XnImI(~%uJl7T0|yQ)<27Zq zpEF*DP>KW!Aiox-Fth(b(CmvJ{+@#9!(6ylGQa{;a{w@gmW-wWJJe88_q zKQ}uo+z>#0o+oLkI1zh)?H^zoOWC;Vi8~OF6sb6B{F2~|o#+&!-?wh9(DbvG5NXj>KgGZC zlADtMDQ|UQpbFw#h#;nfC9$IwioZ`k;<0=adcAPK=fcqwD^_{u25mOg*ViK(A#CRZOULoDt#4 z&=BSzi@={&u9wlq%N+GVtUz&n-rMF_a||bfO!k82q5S+4Xyyfxb7DKN@<#bMIKI;v zJU9tobU9#1wPODIyc_1rL9|;3X7~7{Rc}|-t~kgCB|v}m8oTjy>0vh!MIoTB`2B0k z>Cl-)I6zb9aU5hxB0~B@jZWkGpl!>1bGC5B0`ss#4wY9ZuN*&c-lTL`$H2LRHXlha zkw?a{rRAfnYuwFObLhy=46C`7z>0*bF5Mn&epzG#cvgE8ie9^jz%hSqsi4w5{Z2Ib z1>Z3E+Sv+*lSC7O_DXPbkhxISu>}hSlwZ`AS@Z#HOfRF4mUdAF&PtRDOa%7?+As_} zTMp4tPA}nH1wNFF759h`a+_mDGqdTgG5680$HX0Z^ciG|QS=^pg8ZTl#t0%+O9$Qn zs<*HaatxU+ogo<@Y`^b5+~jiNi&%TXZsj5MCBdT$lZo|#^b&;JPLjD6P%XgtzR26M zoE1)TyYuCk15q|(^_3o80->#xK#Q_fWdi1 zT%T{eC*Q&8Tn_w6^XJdiQ9%&Sg=|@RSWiTD#o!=_9stYALPpp%ogKd4ij;y)>>pWC z{WLWV$%tf7Rel^2860ZlEl2Q76>HUf&k{x<1;D@aSc7>N@Fp@S{CWDROkRgGjCL&TGF}NCBp0LOqx~g+<(02wQV6a}##SsC7t;=iPB zAVkl$et468U4!MrqSyFWIIl?Uh4>jRo4O64i)+dpaP3U62<*#CuP zH)x+G69BYe)x)W72Z%*VBHU-rE#*u`WYMiVf2KJtui95kBk7pY9Z%>xYgl;ndztH= zB8mwumFeY^Y#<(4PEgyozftoDp$tIeGi0;kN-V&wK2QUr|jaP z?qv9-54x{_4f{a$Dc?;m_4$Ipwaqc)6*6~Wf8MX~n-`{L6oioOoc{(}FZZXUY`C^H5QWOIb zs?+tW88c_j(H`|6-3_f_D2tIa{E? zLp0J;cbiP^q@eNq>$0)MdD(8z5s=h*3YP7mWwKhK((|Yvkpm-y&PCmz^a&I(B0d+m z&q$1;VCkdKA|li*Q(nP|jBip@gt#kk8JLD=bzLIU)6M)nf}^+F=+%3P!d>WQmi{su z9l8wSqkIc7Po~0!go9p(lDv9JTw7?03bjNvJk(&o@Q~sJxiOs5Rtq(8e#L>AH3da{ zrv)|Z%^;(5J^3DEzr(K!ucDF4sJXQDDV)Y;elx!EIlUnP#@UTY7a4HT=->Z&vDDKl z4qG${P)5S}S4OPDRl0p{2<4$&s~C@qlw=?Z_Xu5Q0|Me6c=bYCsPvaM*n(#%#x5SF zGNoiVNG++4@{E3Wls3?i0RwKrliJRBqtV7$HPrDIPk(vT9Ve%!9cmq!iw1_z;|FB! zfp!m1v=1xa>3AQY%PNrz4%(2b#@gE}cJ8W|T_L`nRrjBBBFA(Sc@ z!P-7c&C*6CA*Vm2k;)Cq#6ScmfI)<5&smT~%`HrOX2KOVpIy}))Rf1q-K?22?{P%K zt_y2ws<$U6wU@w8p@fX-NO)A&-&qmsQ0-^yQwi!{E{y)F$lq$57)D9;T~V`Mg0^w;7^{C z=E4itUfWwb-PBLuz*39>-0C5;YZ%YMFeEG>lA-0ZaS7olg9=b+go%k0?JYlX_(&O2 z^H$chKN!Ubp|fX!xkF}26H#u_R!33IA6~a^G^}}UixaV*nME~4xwemT_o@x*K(PuJ zUM2z<&r-17Z!s<4>zi4-6t zFyZ}^0Aw(S3h&>PQF$oJmuaQ|T-oYAyJYU1V33ul!)4ImF(K%%*5A2R-C*|4L97Z8Zz{^ z-o2mBZxiY1{-3j7{1!idz{vmfzEky$+pBMlR}VaX$947V?dF}ci|*!?T(v4OpFTXH zM88V$>R*inv(@Qw-6r&k*%n4fkQ&A?yQ8?J(UP*p4)bzJ>Z_9S$^$(@J~H3G+*%%% z$62U#oWd3XD>kPB9L=6SIE$wD8_o^y{ZF(^6DmE#;0C{WE%~6*F9t}uP5HN6cW3s( zI$tI5?yj|H%f#?2q1){UNejOJ{>RcTmWRB(ZAu4zCqom8Ynon>o}HiUmaX=x&29zZ zX@-Y1aX}clN-a~@?MII2_1gB6NqBsJaF`oTpM=)|MTz=9#<(~2TLwD(e){R-M7t@= ziy9}n74-8AjW;{)=U0_ICn>7T85@e4@?t$MT5bHUmiv?SzyGm`ss6-};@2I9`W4DM zoc?M^h+U^8b-u4Yaho_%xoS3xM*t=nn7n2)895~=0 zhXASB%U_nLGQ{R*RrRZg+n)N<L>kiW^oyWhV9uDZ&Atku>m%w9V z6Um{rJ@;*bb?U@?>^J>YRZa-=61x&+jb>>+j@u^QWd6k;HFRzHd`U6;Z8l-QDO6_$ z|5+FwZ<0P@{P;XB9jB);Jx(n&O&B!(=3jN9&=U7#XZPfqqMMMif`s=<&#~uhRSU@a ze|*zzk+~QB3UkL{AR+b%HJ-cLeblr^wQ54+ON+$Vn;Y0uZnv|ePB;D3;F2W$sx9Pl z_XJr)vwc*@@orQEqSH*YPdKQ`mi;5jCio6oRHL55C`fSa&x02Ilk1qD&<`z^ZzjB3 zdi*Qjqnb+s@OCyrMOW}2!=oUG`;gzJ_i_RMxA;@z$Z(mIu_rZhqcM@#s;-d`?YWZ?j$Fxe2Itj2&sIw z#ls~15Kmi{3)WR^c6p%WcA)Y@w&#O2!q_}HTk+XF+U|lkQR@GTd-KVbC4(?6&z)0i zB+uP^(`a7Gk%I@9>h3h7dHhhWIepvhn?J1D)Y{i=$qTgOol&^SeKhi;g?CBj#Vf|Z z4S5;E>GJm7XAPpd_FF0syBzI4tPND={HVEh{S zTK~zH1NZoh*zx`2S=X!Wzw#N;(>^B3@DoJ@n0pWW-z0H|YsSbcj5GXbnpXU!sb=w| z*GB=6zPVq1pQmT_Vpzm+CGBwO ze99FfR=1&5-_}<<$2@u|Ra0q~0k2DJ3L$c4d^_te=i2=|;Gtv3zWeX5PyFWcZq;(` z&q;|dC>G=$+wPjoaXK2CU;c~5JhI1)p{TkHOE{E|t!c)$+}o{tJ)M6u&7{SNWH3PdBy7&P-#lP@PD0M{1yBQ&-GXP z)m1qk`0cmv3g*RMZmgqqZt-?2{j{mRx5h&G4=?_>LAC6Lcb%~P-KqnNUUw!vCn_y) zdv~K3R8jWE?M01odg^V#7W16i8~~g6v8dNnDa&(SHpL8B_TIOO{P_2Y6=3|nEWjg2 zkD5In-Sdz#053HKwDR}!FcU>GV<}s@_>W2mEYZxi*yO{d0(;JftTX%etIHu_akDVP z<{`;k1~k@)(k?wW@m#w%?})m1x`@*~*)AsJU=hZ;fdMKXb_WqC3FM2wsYGq?jn#LzI zyo{qps*$jFEZ6%|X)R;H|y%AEUg& z4Vko?tezUjZ=c5nYH|G;BTqPtRi6+QSQad}!~VzZLMno`&Vqjyl<#`}X@IxgsDe*( zP>%K(X8c`VqMKXHEdWysc8le#RxJVDJix8FiPqqG-UDrnG4z@;Mw&bDM(Pdkt%p4b z_^=YyE>dtsjS9CLCshhb?BNGUUxmt%U;P%sF$}Omr#3SaJL$z|Btl4ja;&e zD;=0Ia{Ow>(iA@3lRfHFXKXI(`?%-TREg4ae+lpKh2Cgz7~$2Sr}^-ui=W1STvnqV zgcbwOd=)GH&7RXscHKBdLxL}hq!9H>Ug9Oa3z_0tG2?yow%_z7o>)4%qvD#c=~mK- z6^)1en84FO>9HwqneNUMm_5#q*2f&*|KwXCz+h=*czZ(*sc4n)7K8bq{(f=0cxKPr z|C*E5AGukmd;u%W@58%^iCY~-yi+B8Q`^9QlDUH(XUy4AHGj3A!)1APx4Zm3_{=E& z7TiFbJ^7|iTsExXdro^D;(2VYVQcDdb+=z2A@}y+qd+bx`PE@|2S5Ad6H%hDG?L_t zKV~p}N>AfrS8o6)+Hm@#bfA<{kk4z^;2v@~@dBU6f26m0_AFh+?)zSMCmw1ct zE3bE}x7}4bc-33!uTf41thfyNx7aYgf=lbFI`ZOeKhL8v7vhs8;VBA z>n`nbbYt|xSO8Es2sI9>X2RuXP~41xl)11dNJgcyzrp*!t|vGNs8``q zIBdFrA?eyTy8t67(41+<-@VrFCHkjsI={rX+ggGTJj;{`tHv-sUs?z` zybq_f_03N{Zuml(HqCO{xFX$5dk|-UbQC~%Gv+1Vw2Kkv+d+eo1*ouc z_RkY?A89m)z)|T2(A-F1r2?go@he-?#8}!o0n%6eAM)Nis>-wL`rXDPCPq!M7c_~Q zsMrfl5MzzKfTBpTqKG2W1O*{QO)MA{1w_FHh$yJ2NKr^ouz)Bah=8C7ND&YOK?J_v z+Cbjt{k}8K8Rv{|oN>lD{P8@7viH63`?{{R)|_+A<$iCe2gV?g7+#+LHCfU4TXRz6 zcaHoL^W9Z-lNNJZDH)a;7~35#8q`;&h4FBL*wVvTGR zjOCBSKIdWd7YYhALK&6Z7qa)jKV+V)x{9AZ7M*?*v_|RZWUuY_WBV8zT(O{g!e+T3 z@0_hnh7}*0hO6~|vm?M}Umwln!tq};EhenLg^^#@`1wJ?y4ay6jh(wWj>}IVG-L-&{~8r~$-wx?11= z%<~wd;qB3a43r5Y0;p8 z%KnNZ-|{WG=gZ-Lo^7w+s0>~U&y&!JAK3PHZ?Eln=Oi0y(7*`BnX-hV7JARAJxWp_ zQd;kBLrwwrRH&vhb&_H!tEGb zeAMklSJ^?Z_mFcNwC=4D`OoUR-!#10?=n{9wkmO1tR96@aa%Zr3t_dKCXl8qh&lzs zD7q!iGY0j|M(-pPD1%MoY`b_IzOtehF`KR?eXEuO)_p44md5H3m!gQ0U~c;#ltQw0 zq&22!kjW^*D`)sK>ur}W9@kuJ=GVfd)6C0gxpi8*<2epx zsGi)$qCrzcP?yK+dijOITFHkAkJp9P#avFZIHry zSWv1Q$*QP4BQA#A24qNF%)#xiZ$>gad6uPv<5OwUBJs!J_MUAQy!d3FNqd2q&BVVs z$2)Ky-yNOZQOqg$BV)VAip|lU@86)kxc}VB(@mU2gl(li zpvLP(>}4@dW!J9unl@7SE+h)r8_wb`7r)SHB1$iP;qzt3j*nT1)ztX))jBfu1>&-p z1NYi3$7$21iPJo@ZbF@c582PCuR?usHkYU~vMKBvS$n=>H01Ijub<2ngv{KRf4$(S zg*Z3UQvGPj&fQ{V43}a`bIujA4t`EtdyB<=ZT9ImGsi%nplvY*{T+GaB+FeMiu4fv zxi5S7d}IS6Zx>PAqv=!Km~D#~EHuHRXW1*%0Aj5hwdwHp)^Z&R#q|fx27EHv3hm(D ze9>zG#fvG|s-8m_OmkVp=WwgNRTUlK(k6eYa|Q<`Y2KJNT7N9t9Ee*0QBiPR>LcqH z_BVK$W%zCR94tE-t}a`ciuY0UtAtx z-X-!YC@cdjE2|wBdZRqVUREXpZ_inDU4xMIUM5fimm5(*hdjJWbUY_Uv@Uh`O^3`G zvo}Xvb0JfgK7Z%LZlhy1n2%pX^?mlD-Gx)+=OU<=Z)BqLe+JbPIn*=|&f%7J0{llehgBr()11&TK9EC4+Rq5ai4fK9*HfRcyHUUFNpN zYnhf+#X9|vOj2?vJ0$QpA9Cek`g)<-Xio-JE@rZtSS_ZfQFNvK_T_}L+%6gYEgezm zY3Bw-HsX+;?pmi6M&uTQ8eFq#&JI$5~Bjc)^H-tc<24K8FWH>{dMp#%cPfc>YV zwuL<+ey(8uv*bgQ+-j?Z>Aqs=s4@0Y8cg6xL*pa%N^OoAw4iE4bU&=s9&akE16Iq} zaP7W{_tw-On%AhdD#20GZ^k_Y$?R>mKwt0iu=McHspz??k(`|4CHi`2rhfvBFP%o* zDhcAu-B{5#b6yAfe<(FWIOJb3%RHTNhQl*+g2a1P`ti5s`)opP;lt=oX}6oFQ|%5) z$7P4$Je(|(n7PsGkPS@L>gF*Wfbj^JWe;ox z#B%rJ!aa_)h>fAKh!iC{eA)T=c!7Ljhie_<_r6DION_z$c zMgCM}ou}CUMW;?_T!(bp&p;IH1F;|-k;gSQPeLB5KXOHS_ zi@Wc4lXptX^T9K$!a=I4vl#CoIz|Ep!()F$iH=$wJZRjpXJMkgA%Z`m`A&6)mQ(FT z-p4jKrj2;#Gt)$zKt;jKjBJMbAaoNVvQxVz3zNl=!UFLCA^lL)R8~h?D}osj6$9s@ zA9T#n@jZ6TC{Vd?I3i@Y2@iUdN;8V3FhLtBW0A!W6S3+=#_!t68)MXte!-(=$VRnH zzlO_n?F{B}Sth5yo~(%?ZxQEXGTqw)X6Rz=>v!WG@j-R``1)j%HF%(4Tx-o!`~LfX zthVOB)QinEmb^ca%jiKu-DFtmYow9X@(a0z41HT2mMM3UhbgWb_)y_VIovq(J|fvb zv0X<_EAau2e^z<}0y&T1$k^xCOCti#oY6n`!*ps%QTU~yIwbQUXn4DQTs!vW0?BZ5 z?LZ7gECthTrZv;=8y*e8@+_F4|7m?Kma%w)shm!J?lZ!K>20#frIS{q7J1DF%!ywS z|48>#bHR0ZE+0`4PL323&_I9}nO6WtVCz0}AUBvQll}rSRcl=Cg`VaZi?QtzQ~2i4 z(B@avO(U$5DOUZ_Dm+yTd3f8|v8M2s%4Qx2N2>h@*ky`hYu!nG`wLP8sq&RgvyEtrn#8? zukR@6TWT)%{_}fBrZn%n|NP#pQ%Z>9KfmQVu=R8Q*LPzM4zh0{l>YtZ@)^w$?%!|l zkBb9#H8<1T8^O*JU7t#9Ikchwooz`EmVFT@niq|Q3HristIleoJ8K~>N{oFMt zv}@(2{#-z=9*lZ`hsoBvOrM|liFkW9Lw$-`FZ9gC%vK72yOmHWQ)2#}BGkyy3KHGUhJZK?HHOuq`m|G@vxI=Nn4tRSwZ4C@htdLqXiKC1 ze4uuEvF9WRW2$zmIy4T8n9$a0XUl$@Z?4hut(M=(?fIa+&xlhc zV_aX%Z#hqfbcHFDZ?(Q@e%^X$xkuGKU0b*4+28LitoX1@-u+AIjHd20*{ZXp%!uQ+ zck6Uc7lR|vfv?_o&^Kw`idGp&R}UXPyhiI&g~OEX+iwb9A!UP6)k9u9MD8?0i@auq zgl)|9r}*sA{8EGNGA2!_0pRa$d5u~7wy80vDTT#*f83YGzQ5dS0gr|=O3@=N#B>v{ zzoeJ3Z=2fo&9^3I3ms8Y309#+JrubM0C>f}b9w=vF*oPmUF1o5q^(AYT>qA|uNa_A zW-s5*_G)Iuxf?EB5anMAXKPCt`bzxo#nVZK7LejpKlz|YcoEP%x#vW3)vez`z=Xwk zoMkW=Q7Yh4&AP8!ENc~t#+O9vzFt2L9XcB!nKISb@MNS2mHStHbg@(OrC;9w%ysX< z1DPqNy%JV6f>Q2&b#2o1>l!k;80+qZvTjvWUjO}j^8qcaX1_CQepOGa z+g%Mj-4whc5+nMY9>1QnHbTAa;odFB`-z2h^p3Dx{ppI29P3k-*t&k)bCC7g(mhmy z`e69JfNf)3_p}~Uhj*&(cwoEf-z6JrG}zXkUpV;hYTS|B{&>-|786>9>lXLL%8Lqa z=D!t++8a6DwNh?057*%>%j=y_j>-0JdCoh z@d5KEwS=CEWzq^R%}uEAl&*ArT7NxNOpQV4i#QTwh7GwZIH7bmi?GhE-Iwokni>O2 zOE*B08IW#{Q)%bAza^Z{H>Fm%gDiQc2-K206HY|H;veIJuO^@hh|Aj3I3! zFq4weV$~bOXP8BdFgS3ft}jYDVc4qsGy;gul}d~K@Zsa;`{X)PjDWaFS408%!jl-c zO(!1j2a&4ca2I}S6eVeEf+DeuY}OAModw9}KsA-R3sBI-{hDfbH3HR1JDwCk=ir;R z4a*TYgqJEI0U78P{wA*XjcXH_Pj26mW}KB~@S~Vvu`u=5pR}1;3oIqT^ujMtB~=`e zI0Y`|(Fktn>GrYmK76=^J=!TziBqv`4t!YmPHavYb^YpGP;fGcMCltR>F?lG7Oy`! zMiHa+s`-FBJ60^`;lKp_m}b+ZTQ{+!*joF1m&j$9ymg|I3{IcU3+|45#Uu^oXcz8} z*1+&WsZ)ZZDrZX3Mn_)0&j{iNssa(n-_z-?IPv0P^Nig2gGG}_nK#bll~Zx*VCi^( zQTcXcj2y z+9-y!Z`mNxj5%8>L}UaHMw}r|FBKUrVgn0OziBqjFU%9~$j(l8X{t93Nfc}L5^w|S z63-7GoJ2ApR2M5E-*2@-v8iR5$(W=4nMo}-irPVr^M8k&DV|-#m;VM# zpQTD?Dlu1AMVMPSer`OF^5iGb+~*)|#f@pvk|nRwv=j=9X6{OnyG5pvQB4;!RdfHN zJ~GV!*|vCC^-gTv>aObII_>xC3-C0-|8+l9{3iXHRJg_Bc(H1~jz85$K>S)Wvj~0t z8!Q+ZMuAE3rCD`t6lv3%PluRIqW=QkOyfYeMYDQ(sg#xgbK)-WAGc`^mKY*OI_73U z8Yj~?TM7y3+d@k|1PHfC8>oL$pl>v^Vlh@dh@M|oYT{|0xPlKVPqbcU z%t###xR%T=LW~dqz9b_rpfZMtu&3X`zPwEc4x zux22^d_AeQ--GIMm!@TyI1fFj&LRoRw)e)r=G9gQh3mYQgWf5D+x`|WG56X!S%Xk` z44N^Ofg|G34r-t}Vnj%) zmfwSAm?_ET#(*=7N1=1%9U`_=9zw9Vut@h$${PU2$83Ti!m9XS3=gRS=+jiO!PZc; z5p(0Sna?cEZ`JsC9@esiY@v0mbWOR~WN>`iM|`S?dG};abM5Ys#_4^U&IHjhJ)|fm z9~=^}>Mb@4&X>Cz@**6@MQCzVd-gw1i zBf~Frr=JHJ1)10Sx$!c#mej=F{Mlts`hzKGkGt)m-A|CetF4jIb@yiat);(LoHAuH z$oR7BlR<#&3||eMq1*8GkhDG-Z@%3%f^G+9VEt*mhNba{up}2Gw7B1-wKHT0g|xk~ zM+LZ%1TAG1BkJ4_9C&V<#ff2bZ+TeeIC_Tz^5`#K@KOhUkUogW)WMOEL6X(X@=_=Y zTB^hC1;LSJ$v}C4yTPNAF3E5X8BYV0>E+hMVhbYAQ|5Rwifwe_W>m-0NQO>ZhNQhp z9>s!6hm>Ma-TW8gT4(_PLMz=6z2OEF{V2qK22dX8X3q8-J$kgXqK4$RZXW$(=;F+2 zd!aej_nxg2wU-d492MdE6OE3wOnOOyCOquzz06+|6hvTL>E8ibSG6nUM6<=HWTDIm zpA4sF=h#5gyp~ERh?=rstp}`rBm*5TUcC6tesPZ!rbmFn{MiacXiIvftLO85VGikI zngE=9qDZ?e>2U>;5p8vUO=-x0y}tM9 z1u+#q)bGJNB(7S=1Ndy0GmS^+$rLyAjx$vNCVGl7(CzJMQI0m!zc}?-SjGlP4&)Yc zPS+H+iP_bX9u9g$BcGVrNE3kHn-(u3Thze&3#5M@Y^)&JKX8Eb<oPoplkmW-i+XX%0u5<(Cd~xRrLMaZLJxwVTnk zN~et?W=zYYa))qJIu)gC=5|J~6a|Y5d8t81*v>Eel_gKR09!!{l)v#yJo`QQB zSZmsFyGkGg4Jg88QU>XMe)6;-t&4G>O;2(2YXu4ER_0@wSTP9a}PyW70i3tMsYWqnySX zb?Az}gER$!CnZbvyCiFqwIWW4QqBa0A9H-ZqdPnKBr5h^tdy8dLN#(cN4||eJfhhR zZ{lrf&DCGBWJ$3n@~(&vzsFscREtm4@O^pYWlM07ERzs6*4%gvRr^kG#?r8C? z0^AvK;{nB*XLGLJB-9_9J`i3~gbmgEK-IE(V=ET^2Wojy))8q2&mStR8w1XPJZmZGK-N*Ct@{YUJQ8yfO3q(tl zY0*}3zD06Xu?nsA(UB{DEkjAPQvFjhqjT;0I!q)4Z_!F8;LjXtznz&S^pl$h*cfp{ zi_2~*mErfka?ikdBhJ4%r-8-ER?0V!>Hz^It(#)`BvJtZW=XkG;QHSw+*VWi2+sE$ zgl%e(5fssoOHb-hJXjB&2ev9W0qK30I+Egzy+*D4oz&2@xQ{SoM9cFK^FJRH1e(Mu z0B&{BeZRfY`7;k(R?V3ObopDh(&vTxk1;|r%Lg7t|4)^##QlR%97ujRw=k0v<5Icr z_$?3oRu-X2lV-l#-Lz)OxT7m{%y>kMiYCRK%L8lvT0V&_17Z-uN{W0!6hgU~4AX~f zD%bYr!J5K%r5UFok#7~f;q*u&(cUx6_Q)B@H-$69oFpZj3S>)>FJt1}eSJ-fYst5W z*qX=Ry`wxl=0>F3yv;=6x~T=p{E#Y107=|yk2Y)+aL{d>LL!JUw?MC4H!Zv|oX`z338^z7m)uRXd z9({k+c)Q}4ws|9Ub)%n5|3`+G6@5TDBM9ZJr8_(HJcR;VLnVxY*`q)S26BZbPW`@f zrzCd6_lu4sWfX+T(3%F`SwqJjE7yVc9*sAI(<}0%9)Qi(OAeX0-8Sy|cX zxLrFBb4!J5c}r#XVLpaGJI=4VocrW(d{yi9%s((>y-Op8xm_>M-8{T3solpP4;nDw zOzI_$LOl?{OpEIaOZxz3)kJY~Tj0jplO#Fqmlg^t^_N}i)l_@}|JC)f+oGDn{{8!N z1d^;2)1{Ib{juuS9P z(%fKDa0jQkt(WZdi)!M})EdI%?GG_V&o!%wV!zb(14_7TlT~Qw=f*QD`HAucTR4)5 zUf~fl-jj_j^}i4DQTrGGiVp0*W`*c2T-d?C`OcPEt#O$^jmK&Fc6}!Ut?`>OuO+^A zr1hlOwgYAUe8<*=X^g?{QvjH_v6=SSEQ4;LW}!FKpu|szqS9lb@xro z|Lf?{V%yoI!E(q|_aT*5wM&qeHin8_*DmB>3P8$0x*MoBowRy*(&vhbjV-yxJwJG` zx$%rx@D^vFPtLS1rkEya%JKWPkn-r6`Ex5rk#!V#BDuS+okPlc+^q=l$$b`|Q@<9B z%esghH)Ry6;54H7qOz+p#g}qv9`3&*-Wjz%sM^s6HHBdm?m)Z2Sr($vSQWQ-!Q#b> z4fsmFF_?U~^GeueO zo$G|mL8C`sP7N3HN2x23(A<)vRw+5T(%FtB%K`o3=?ntsN3D0@Plk4J*GrS12yRX1 zeyT|CaQ*!zUVLhsJ{0OoIVdBlfV>_8pG7sEZ5a@&P&iEmGi5+)_kcWMPjnlK=Fw`E z!og`vfL|F;VzZ9JNW#uI<_J)bOeYAZh{r|o2WWiJCLb$skTjSnf#^U=F2DQYi!X#I z6>$J5#AdA|3@Y*0EVJ#Sh;h+q;j3zI3PH9{{n>nb`)eteyu7@KJMK7%VoWeuOn^8I zD0CJxE?g*P6b=y|Eo@zYxtYz5Dq}t3NNyHW5-$4r=Rc%YML8%1HtX(QV&Y?JLMgjO z=U>MOcN}~M9+@XoH`5!NildXri-e+<&R80IWFnYH+GdSDbSPVUjs8gC)V=i{2Q3Q= z772np^Wpst`TB+szv0y#NBs*^HwF=`1#K0aqjHqs&Umb9AC*=s^=#WiU>YUKwUVTR znOcU>X5NF*K;h2Zldy~y6rgrjWLd;4r9ieBsbu}?4^h_iICsQ9M0_W+naq;E?}eS_d)!m?kwBl;B!d>~FU>xk_9wlz(zvqx z?~lNC9_QKcLi?LzhHLzkY17uzNTTdD`i~c)o=X_2RA^VHgTyoDEP8dr^P(P zl+ALtq%@>^$vUJbqC!z2f}l?hFZ@VRJG|vQ?<-bWOLFj6LYn>#=siVlmEt7h~AK5+=c(VYnEkBufE#e_WY@%L7MBQ}V z_B@{(X9F4V&T^CyxwJalS~MhIw^qs4$+`U5=bvvr+JEfW)Dl3#<}bfJNWMrnW~%5N zDd_j0=8xZg`z_lil*k1dc6Fz61d#N;6<6KNMAB!ST9!;9yZ1FTRkB zkYF54xX33lg&9=s34kZwEz&yyOb|5sLdiL#JOSy?4oPNy@4;e9wsl&BCi1>Y4OA}T zog(5h?w3eB&Y14mPlU!ER>snyV445&3x#WUsQ|Gkx$$-@43{*mcs6tF4k;iV;xZsa zig0!$d4A-nm`mX@GmXV0HaqT~p05snp9A(RV1dm06|)FZI(D;=)#8W8V!Hk8W5tOj zEh~(iYoHiY=u{<)@h~Am&f#Dm)$+DHtW#Momj4F)mRqaPSF|jk)^G6t|4$?hP5riQ z>AS|JA%63Rap@h#B>dB3_1$M3_LW}l9A`OX*20YX9g_xET2y^;e7F0;R{@C!^;dtf zbYaGQ!#jrSdc^+GGxX>Rv$FF2Iiq$ZWR<^rkhL+PYO|Y7!S?9#^7<t+9c>U#%09#Ff65j1Yv-gRrVnU#whDTzt%hb^nWs zAosw~vf-^C?BD;lo)r2oW!?JC{Zsyb@x3mSfUTuB;j)WZKV=VWO_8j=*tDg4>&>vb zL65{b;&(;}O>$XF{&gAFT=l=&@cwVVb1z?)u|9pIT9Ek(gxS;ha14Qnq?#lnU5F>W zxN1{e58-qRv>SkX9V!1K>e5ppqaz_9HRLeVCuqTiY!sV5NDYrL^`?JwLdy{hPB`*- zec~GIROt6bnr?WNIEJs&I?{OMkx!SLyPL~&YKT7Qwx95-9Q=l z2HoLZ#==0zC?rjvuq8w9B_!CsI!Y81a7a$=sNTM@tX7ffeyt ziU5|<;59tTwZses?U|lXj$GvIaR4hAW=Up6Jizg8DS3M0772XQ&@he#Eh758eA>*W zxZaR+9Pu6BDD5EXKLD@Ce9)AAEvyCn7d=G6q8>B&R-qVlMVxksQGp^7zhZu9#p%$u zO(QlpRK$KM<7&wL$A~w}@UvYMxwFTPZ#(fjFwC*S+!43#*tm&dAOr?$KB0Q^)@c%z zaKWF>xSBgQDy%tiQhbbm7DQ6MCQcTv2hdbZC6$qRP-=^nTWDk+$dDd5nAlxLY)BMo zJJ{kiiyrrMjiVeEAiqc^7qET(bX>c|RQBydA$TWB0|2%amB&_5%8O7IL4JA6J%u88 z8plzWl@R4=RBICgXMN0 zc7F6-9XHrg<`YPJ|9A7Q@t;Jt_f09kNdiGdun=A2@MT*FtHO$~e)=3N2iIPV+_i4y zNaoXp!3`F}OxWM)9!T5mIRv;={YyKY?V!l*q`Z?-WC*FyDQP@twKLqff8OcSr^S@A zBJL$%O9{B%%E@$&YKY%dDn_Hgv-h(7low;!8M<#F#phe7?Qq+9+o}o@#W}Li%>rr- zDP}as{>T$HVmX}YZsPFwb4Optf}VmcJ*Sj@FkSAFh-_0K8IB)up(#QpMyxO*Q^U6H z1kWiT^%98F4LIVtcR9z}dyvS`%xxrk7_z6x`^v|c^@G%RK>p^Wr^AiaIYPt)^v$ZEM7tlr#JXv`qU{rRs1M+)gk2#gKFq-j~ zRwq-`$D2ClcZFQlu;U7>wvC5x5Kl<@xbC#xi2nVY-yT47ae>9?5u(GE5SJ43y3os) z-@ZIFZwQBixdlrvGi6MYc3$fCozYE=%OrtMscfq_7JGmFRl{uKk5Nj@VAlA4 z1V@Wl{*jm6$F5RpS<+t_;;d%@{56;2@WGob`cnjEK&|}G$o<$lNpr7on2ehf`>?CX z3n=%V=&PWsNRa{Nw>h1&Wa<$)qsvX7DVJ!vlIbtMENY|3?ZLy)hqbLKfl-mk;rqI_ zQOwuCXKJbl;=w>Y3j1{1GcZZdi)CU58j0X3apX1Zg@s z?M9!JgHJJ$en%1-0ol-cm*j||e~o^ZA$XaxsAdYR42{q+YiG6XZ~D$&;YyQLKoF6aKp=^QtZ5pjKzoM`}&wuFv0SzM$H zib$oWcKhPXK2-g2 zeVr@TTc@582bRhKWOOC(jzvs5Qu3uc6m7;_aNI3WZvG+LtYTuXQhD(-*jS{;RTkm2guNh zG4tqb6$Q#p|27lb{1iIIH8Yt7_ZIjdf9rM{N0~#ruBJrN7|z(IL{}j&1y2E*m&|~e zukSy*C(C6X4DEcLyNpu`NOguwBYe*kmWGpAHVd-tvkq@w&|g0vI7(`JQK`d6UJbwN zc?eEH!noQh*Z_Z4Mj(fXS{|K+vZ-P_J`!S|efnWeJXFmY-n|4-xLSzto#Qa1^&WL_ zEZ|RUd;<;oMtt0n37;q?uGq|Fz^V;RuM%#JG$z!PW&jHi5(#kIn;NzN7Yzmhz^IlP zAZNe|B_$HcIu9cv&u#p%h3_bDuQ26HrZ-`LlyuXy^39tCfZ{B*GZHR|&+Bt;Hq3eF zWM&j;KSySdCF=Ix_DHFIIx$ivuESc{Xm^Kag8R+3-2f0JoT%hUGSUN4lL(;XUNRBo z@CZ{Aq0>0SWq>FU-@WJK)~;D&P9iBDn?n1^931WeVtNBU6d2GMG#~<%(DS$ey2oZ? z@ZCfbJ85`|{aeK>|7wOK;a`f-6J(DN*XKHOI%cqy zAo6zfNrHiud;npoluSTs18Yd; z+uSc(i|w=ljTI_1 zhdk!2e(StDEG(>A8Z*fWQdU$qlHsP200!YYUZlmGEpuliV&};<>@Uc``4o%V_wRef z4PAQq)(RI%Q5CsGv~?UTsXxuYf~B+xu1~yk=OO@L4zbNRvgYEYJ-=cOUPNM}P3srkKomsvEK`i;C;&zm_{mJD@#uLI$Fe`krh zdGamWo6NZCQ4qwV^I89iBG+Ez3f@ij)22>cdLgZ;e9yb$$M=Js-vFbO&?k9;mY{&x zVq=;luYxJrb9OVjNJff9mU-~t)_4j&<6ng&?&(Nhod%NQcb7j?v_HwBJ_iyAmvjiu zXL{rhgIAlGt;cd*05KDX=bx-uz4~Ri8_7-(0$#z4fEZS%Uxm`!it31=U~`m~pTQZp ztgc8P?yry738sB-<)wyscgfZ^IS$C-|2o>}f`&^$5U54RkhE|p*Sc)EKqu~P)1j-< zd(Y=B;zjP+4KbKY;i1Hr;DDt^tV2J|wfC{%lr~2>P{3(MJB_}zPlo0I56k*IsaT#p#>)th+Ab=8V&k`^#)cAiioblFJV|f|RsK z6f@!qT>EjX3q}+|G{~q4?$gnTIOaZ}-~ju(x*?7ck8?!}B~!Bv=9nQ(dpt7EC+Tx8 z!;h!pJxYvDSsb*8P_sx> zkG56yd@6BEx8GFIs7L=U8`2-)OXh~C3GiMtQ*9`?yksc!`p6shoM;@fmScV(N{cPg zoOt`vCo4Q7c{Tr%zs^9-^p4wr3?TTJLea$x?M?*Dxmo%T#kxx`EEcVm9Aj7|m)>lo zT$PgG%IdoZopcJ+>6+pg)5nm>-N(nW_$caM=}*-3sj9A)kyvA49<|ezNXL%HGPH}> zT%bCg{I##kht0wZ*qeGI7?ArQ-Yc9ym5Zj+fF#PU!PvU&X2Re@JS`Y8>YG<5d9D;3 zRGl|Jz3@^#4~oIjdqdZ{+++@kc${>|qYqyZ6D9p=oYZyq4xrZt8J>wd$la92_q4K; z5Oo&YHXaH@r?l@%S@C3VJ^gbcrtlucE?b0VVNOUkOIfDRn$70QZo2FzH#tFWQduPO zH=gM0hyqqwYII}u2>Ki}&$~(aB#uK6DQ<=P@hkM~KXhn<)km?qqTv`>;Hwvun7VRo zV4bhbz}-Tl#+Us@m_VOMw%@B=JS|+8xaU ztwyhuW&MF+Q;%3%VoNa7`ayjn3!>-t;m^yBt0YnYWs8^_tiAogcP;{iPze7uA+)TPHG7*%^UP zm<_QG51&q%Ki?3q5}WM#Iyi|wuLerLylu`HZ%?Z9`;B#U-4mg&^z4Vegh;z9eKTCd zuA8G+N(F+4`uFn^$r3q3;)>pZ>Jp|CJ3j6$s!Y)X4?WTpZFiHqaTVIZTW$LB&GVG3 z^vF$O@KF=3wC%ncr1`Y3*cYI%H6Kz#^myPibK|8YZ$Y{x7yS#lgZCTr48&6#q~5=PNk7w?~*VleD>p&;aG~Nz^G*PN%9Ieu$rpdRuN+(DM_c51*7b+ z5fk~ap7%;Zo=isL?vq`Ojk?BxU>{<@Q5)+CFGeh3x0XJe-4XDgr@b)_587hV1p9od zXty!YgR$`yCR8QT&q%g*Y;in46?RGQ61m7D8HC!V?WgNVW_Iq6Qqn`$)5sP@lNEcN zldc%xXN#c|z+ID>V}emAB= z&(XI6+ujm;f7VI1^#>ZyM@FS3=)~)K^}HPC++NY8gFNitWG)05%hjHs&`E$=&jRgM z({YJJJ2JnP%B*7u(EF+E5=2yYM;34fXLNiot1K zZtN^`-x|i@eg4G3Ia_W#6b+j}D_ki6n`kl=?JI)Sgz67`&f{t)E2fd(L0Mf6Tmmgb z&25NiZOi&O>egw6AIr$oKFh61kGWfS%d)^t$XpNl0!qH2giA+;T z1~bk}x=dRO2x78iL-`4N;E~M1I@DG@@X_%O;4@4_CFCP#COFmM;l_sD_DtoDzGq28 zY>y`8PaGLf2_mymUOEhiWI!hB7HBB_@H%g9brz|M40ORDP3qAcr>cEikC{WXEy7o5 zq5|^W&*~EhOe$qDa(`ZpHu4dz*{87T&?^Z)=#}(&z(t~aR`JF+yEmFOJ#^VkRByN8t-h2&PyJ+*uN&MxN3Nn-XJI))%Yr* zvU$#{&iOF#sbgR5B^(YR+(+{10Z1!>R?J4Rc9KjW%gLyJD_KfjW=G z2!F|j4nL|V9UXUHA5ej>7)=>JDEXac*=s74#exoibV(G7)p_IV(x@44RtUzbXL`Px zXikXUKFbXdR+RBd;{$O4prqlC@L~>1zx9PfnP|q7{J=_b;UXEUKuZPC73y9skLPxLi~Y%9~&$ekO*7 zhBYPZi&%@w_v`G(q#JKCvNltpdfhonI+vIC+jILA9u8?0mm+uz_3GzuI32B>k|$#~ z0Vul3!5R3SdGRln^*X*O6)1%x_`yh7YIL-1?4gWGR#B4Jym@hZ=NwhqG!yJrR_(LS z^vxlt1-{!(A#^j*jzO1Vi6hz7&TuM^-myKlj11M?Wag_Al0IrLTcZB=-+zBaEthol zQC$RU)SNK8vyO*#IPVSMoLnt3hqmw-{MD^%isiS&W6eLcYTWXHFUs0fq@LjUtz$fe*ucJICtM|Y14)%8v3ZsFIY`$P86 zf?Q*?*CVd$;DuA98)`+CC>OBbj@2+2Q>FQn$-6>x__dffGIfGd|CxR%P_bnX<}G>W z{hmk7p{lBSzu+o;^rU1*0}W*iXso9|E_fqkge1QGTEoVn_v92sOcmk$#6wOV@t`QG z1rG=kv}ZGDbXXtE=ng%%1ALv2xl*7ZJhDGewPzeC#*g)D=b4YTFYxa211`$B4(zmX zDhQDm03hUyEaP>;`i!xOMwU9BlQCU?7*x8Xr7N-k4d>+ZC-ySoN%ZgaG5{sp_#4HE zzs?dwq3)M!Jxkc`Fn3$k!SHw9g54U=LIvC5-6;@Q!O&AQjd%R;J?K;?&_+Ge+KAy0 zJ$&l0o;Jj(KG9B?%jotj=tp*f)@1f$Ybrmayz6&vRxe8$AL{$dh&63uUjOaTr<)?n zA7paV`(=NnOBlnG)+_-y5izl!x0|v=668kdH+hvpvsWCGKN4&w%-ET7%H~k!s`Zva zA_SzbDA3wg43mEG3r;*3z0wGOqyjrfX*w%z(8Lg1BCv%(%(lbm|J^xiELy+zp&(*e zSJobIcYhhT9yoOdO0xX?{54t=1`N10@`r#CY15`phrwtZ9aV#z#Qt?sJ5>F|G__Ua z(vnXfeNZ&Gw65H`|8$qR1k4r8<*l9UWl9rEMYIra(U_!yI1n1peW)=!pft*e}96i2k#k6TTzq_B=yLYd*Njz-J z``W19(m=u9za06)GVGW_UDT?2Y#Fov;{89prkyIuxjeRN?YbXKx0i4f$5AxDex^OP zYLoNaewKYTt#4ICkY7m8g8k!Fx7Ot;3C>)HnzE$&+S)tu@$qJ%rMmA523EEGNgT`~ zi$~>=1a?aDGJWvjSj^eFWh-5Pb+>ss?tY=`RTtLi`%Vwd3kmhuC`*E%efp}GaSe6M z^6l|#o<^weNkDAJdmR;4$D{T=_}DdbBC_9EX_^G|S!s9g?(6K85OwLU#i2_&-kxQa z5f4iG*FtUVMb1`w^M?_gPv9u^mKr6~`FQ`4e=D7@jFJ}BR|#st-~ajJ&79R-j_a@e zM=%-k`Gb1D-rnbrXC%dtF=q9J`<}m!co)9O9VO?%X-=pq=$WG=jz4-bhYSmt-*2Lf+&a0v7S+7Nk0x4&?C;MNs56UsvF+TGMm7R%0puaK4gUE zJ>vOn4ikhzQ1hiq8QXe z6bnDvJkHD6=b3S9tj+1Bd15oaHxNyVu^n7XT$vUCk65lpuIDNF{Jm&@p0XC{GTi5~ zZuwwRp^mFa?KTzI{L1k+W)zs!moEo%AItQPJuaR7WgCi#XCi^H+`?hFa};q&a0#MT zp9p(~7S<%0J?ryl;8GA9| z6AFHIA;B{EYy9knSb~B2mt81*{yH&prQCw+#$JwJE+z36EbX~PI^9OkHLLC?j64WP zmU%q=8L~}o2!Xe}urL7jf?qNSTsW|P?b_^s4Ce53+B`6vq#~udc}RHO8kkv_Pk306`^zu&Aq@~a@Z*uE$e2dsgWqo z$m?I8IwzN>_~f|PU}vx;0fpkV|H<}}d3->;K(+FFaa(Voq{(4VGxtAkVMq%=-B zo`S~xK6&S}7ZwMReHNr@L5iKEu`bpuf$UBi zHc~Y87)_(|%PBZxf18gKS-X^rnuLokiym=_fU`wc*|xVwnvpyCSApE!EBN@c8ABGI ze=^M5)2z6b2Qnv}m~n>Fv}YXgGtQQ`&NT#L8@pH8%zyZB{;$LnzQTT^cB-qd?;RJu z(&d>RYd1p)`3Wxx`rkXw%ltAaOfU%b3FT&x(ISC%4vvnRA115;k3l6G+A|OAq`2yH zg=@W({==6SY%U>?q*JWt{#DJEVk&&yGBVR3Pz6ohD-NbspQJLZkLsdx z?ujLhAFAf1A6x)q8*MPF$9OfY z?V*mUQ#jb~EaGn8?%C<{%>gJvd&QmS?!4V*HpaKJez4I}DrNPHBrG1@)YJiMdy4m|Jv67k2HdlCeN&%~}9(r?oxaA$7>HPRPt$MkO;g z${vyc^~aT-osn~rjrS-JsdwPb82d;c1S2{1DVIjM@>a$TX`%#gC<{djT@sdQWH9QMKFo+B9S_; zFq-Qt2V&AA`p{!Nsep9A)MU_Om6kXYdj8{lRhPNqDG*FOFe;3sMFxg1rCw%C(t)Sm>(5>hU&ok(bAUpid+9rBGY_=e& z*2t3k$MF(-9e_jZ+_dO=EXsn~Nf+fU%*fRZ84D>D!=Il{*7=2{B&;HjKip^k?%iU7 zW}O`!{fJjFGjQgVxXNfdF_ zhb%btyLilS2B5H#5j#S#3o=1*cQkmF)JXdeM>2R#6}RWZ)SI>&;kQH~VDwV(YZl^L zfP3%VIY?xYKowPJ!DNYby{Yu2Cu#r}P?iXE?r0d#Nh703tD|nOd<2xJq_B;c5Bre* z>q`Ay1EhO=;7s>BRn-#IxcvdrPAH?Jx<=#Rz%5L?i(G`S_7O5n4>n&!cPlo4JcvOX zu`=Ss=y!umZ$=3tgH|eLWDVDnPJ--Pc$HP@UTDzWbM+|4LiE{oRRZ{MUKZ$D%YU8yOeN$@GiRP@a-4R7LOh=VuUmt(`Pl;cr7%X_UrbF_C3gV zV|Ktco~8{|t06AyXZ%os(q3KpK=Xbw-EeDq7@Ya}YYE)u`_j?D7@{)T-F-xv`v)%K zjks_+(}h`CQ}ELixh_S#Biz}z?H?Srcsw5^03}Dg)-V?~RW(slYZxPxtfJs7U(q#T zwWX&73%zznMFa#wZD+V$;c6pcb2tr^~}xp#j2l*`a_hv*Z&qMOd1o3{+2<1 zue>B$`3Rnc%8^h63hx10?Jy>PVrLBCI}XpgU20=xig4dIKXAamOeAh!rSX%*B{H&$ zTw%<{e^y?4#cA^x$I0vof8ElXjq@vwE>=q>n3|f34hX!{YKbxt63q@K`dM3}Uy{#) zV=ZRsZC)bnfTBAf+q2^rT}vv4_f=$_*_3lr0s!5ZeU7HDymXeV+W&`M_3%C53e(y+Cw&{ANh`irDYc$UmjO2E<~Jx_{z32Cf)4mf`qf*i&51#nSU z(B6sE1x9Px&6_=!qOlRb4N>gzJ^^6Jdg+2VNl_vSJhTRp7S}1Y#5o@?p`}9i%__4X5&23*9}4{xe)Ed-2qB4NTRA?E`SqBW2ttyK zy9&UNdb`E#<;D#9S&T&l{*^XbK4jJ5AR?C-2*6SbI|5St^ruA;52_ogG(KaPTzXWDZN1xfUgKm%7x3~6x6XX~d)sDh|n&ky|Td$K{5$SnfOyD7O> zsjo_qYNWDeS8km{f^#pY$2>*Fd@2}yGWtZ%sM#BOfR6#_Ta1yC);|AjW#vNfan_AE zUUO%paUxk%O3(}NrrFkWIGxGShjAljarYPU&J;{$=@Cs5*`!rfNiWa?BMcU2sln{o z33E0RqS`ZM2^<}Dn`R03y+8V83XwO3)`(nXOd(Qn;GJ$NzNZX=Qtg1oS&B_aMu#5i!Cza`C!wX@z zmJRJDydylm7~e?DU=PT+!?LRz%~4vAz?hBsVMMg4^5LCDzc%~fc|22O#n(ix^wYT|i07(+D@HG^G7AkoC*chO)0704 zlX7MHQ+2kDw=8|q#`Tm&GDXggXKhSw;;%vQkmroPO^@tEJOaz~KU1A z4tGQ<^fEtxAO{wf#@B5dxbH(4+l7FN1$rbbY}?HrS6wx zX@9n-0a&i!rxb#($_^q2evF<&sfZgmP%Ld48ykgfA>$&;D=4eXZ6gwN^fVkSOGVIy zoF@y6l~VsBnGKB|hL^$y>O-H#PoDk?i8YxTeC9my48OR$+pyxYgOLJWF$`Q=p2Ql#^dgY&BFD3uVfqvF0N^e%l>sxFnkK2d8{#vNIli37()-ip z>#v_zZtt!L>rO~j#0&#rlR3z*v%>W#+zh~X@#zaBXR0Xp3dx3saE4;TC(1|U?28%u z#b!&rwsD#NR#|ik?R$7Vo#b&B;jeKH`(XJoXD{h{^5D)lIB$;$p@>{$J_hchKh56u zkha6Jjfy8bsJaz@o$lS+Cq$KOTGowNF+(&tZgt`&xsK0zQsKl1&1l|w0`hXW} z`fvt+9jvbHtFuk5ehE~r7`v}g|7=usPcFJkdT%Ugqdc9@lctMAiQ7sJ{i9%kmnJDb zX(O`CZw)lZY3KYAm;})G?cKoVb<<`yn?E}2KGuutS$O6%J>NEpiN74wi>0- zzC%nmxoNI*rgcn8Z+@$f|NSRdWO=QZ&!$_ce5UzdJz4K?^J`w<y>*jd<{zm9pmj`9|rc*tc_7vx{)_w80q<&)L^7~1;!4x?NC8DmFftWfAxCRaGA z$aT2}+?~aBam`kdGX4cK37>|oockgG?ybr#nmD&~WAioe%a_8MKVbV2N00xaJ#)s4 zaP!KJiY}iDAx!`bLZVxs`lX)lFTYv<-S7VMPjV1N?pLfuS$RzFJszUXd(EFOu=EM6 ze54f3HvUv`U7scqITLa6E7WW!udorZAuD)`@)aTt9K$j+>iJ($vbYsLFJwALto5h{HsSSUY@AqYf|!u9LI)!2TFS(ycHBhF*|5m8eW(L%dLpT%SdVbu}Qh4Svn;3?o@ zEqEXw8?t_yizi~YQ~c=kF7nRr_Uve9&L`E-BOS=_Z8!hbF(xkk$R5KU#Z{^y*xWMDISVL1vY2>UGgPBAZUja}+1 zC?u<8kurcBnE|Cirt`2Aye3@({@k}u#8P%t=#e37wxVdhjnNN!leeYmDBGV;{*GJq zr7if77*E4P*qh#eYsOAA7agp$k1|oM%);wghb|_w!Ic#|E0bR(D1Rfcbfc}=)@Hie zxb@!^t83J>#@y)HvfI29l>fzv^08U(Ef{Hr}=C=O!qSk)_&z?HgtNDG}pKJbCj38nP z*JW8xny-8QL*q@h6X@q4&$=Zagk>Bb@$6L}vL8}|(PHJ3yg$GXh8YWM3tNlGb*VkxAP8ySY?0Q z%cbmk-p?O=%ShziJQHhVnd< zUz+%xEq2g&dasSPd~)%-l(sB3y&=Hq&?OU$AXQ*UzCkj zD2D5k?jjP3i0lwE7?n&A-7|8IS)D}l5lHS*HoWbiw^8sZkF%#L9@L;0E@>_z6^iak z4=dWy`2&sHdh>MrfNviv-*BC@9BDgTlzpSPxQuZHGO{?5E$#XR*uWehEnB%r87;=; z@NZsD3q>jt*S@@5m#s>~{ofnV{145Opc()D%XPXS9RK~xb$X@!mNu@D7?A{_zvL4r zautH_ARFO=vM{?{LTXbsJnvd}@Kfk7Z#TqbGnsX2dG4&$AA9y76s?>SU+eO$4tAzX z=gu}Mm%g96`z(?@exc$zT@jqba=0s8Q`}L-QjZVXjs` znO`)gYu8TmwBP(b0nwtjo7nlve|ZlX(nqvl9OzlL=i%CkY1VpEcfGw{2C8A{MY-lI}SCE9S`O9^gLu6#@OSzUv?E$-!xcztwf8G3}J3SBm-+jl{*Hrvp zT-pE2zb8GoJ-qPjppnC!j*Bfhix84!EjdsndNOn;X_T;(wv1UAI2gnoVqcKiUw{2o z*~~&UfUkH2b{#KEN=nKue-Q=ZC;nXG1B_K#7e=Eo+RM-1{~U;-2+ffy=;KK#bp+!P zq;o&pxMt&XFhlhQU1ma@mqz@6fVEA!rq?zsm0zPgW9&k(GUIP~1Aqs40~8gT>XZUv zrrjG*mV`F>hR_TUdH>V-3etBAGUE)F9HR)^8GxWd1T@zFfX}C-;^bkXYy$R)j4b_V z1P%jqtfX_v67R^iWOxgI#0w`i+Jr0(^@lYVJfhKOivnO0e7gut8Wo!hQ`_~sroW$1 zvFn`ewo;Xm>-X+{();Y7BRvj12^naxYVyDu``3RCy!~X~s7Fs$?KygO$xk2u{i%JI zTU$GiGdXSRt`@CZxx@F91EUK|b!>c+3_l#W`h8Wxc@O7;QpdIBKWE+8)*V&k&$DMU zQLNX7jT;wSS@-Dtw%!H?29(b2($dmSowhJBGU@|q>U!tSo%7qgX@g2%|12XzGpOMH zVbveC)W`}Bh(>WINc!rAcT|Y8*P9GiQ|rbDIk;L-`%br=@njOUvEd+?*Px zPdT@s%P0KeQk%o{k}R!S)2Fbbs-~u9NyJ*Ws!p9cS=rjYys@v%nlz^k8-`LVF1NJo zv~1b3w5YEy7W{c<{*%(#$v8&DvFJh@$3uNA?-lD%Q`!>;V0E~few>K4IAnSUxR_qI zT0!O4g+|Gr$jcXiJ+~eB*cW7L1v!m|eQ6KQpMCD`?%ECYRL%dSq;v+W+8Gw6?yYZ) zFtaDM&TVv2x#i4ZYWgr{UemLma22W)s9$~cl`CsHzbGg3>C;oNG3zp~4MSplyL)lvw?l_^ zASflAW6IE5?S^e@*RFNgy7dKz+1ZV=Ug^xLnLdC1S8RPNfNexPZH)?(QRET;re=gyYZQ{ zl-AL{(&fh=1JWz=-^If)GW;|AGLR$37-LhrNP z_))+hDSie^A8;QPl|>ioNb!b%5-vo1GS&B#J2KS6lp%vHEiKQbIZ3w~S$#fG*Qt36 z>ZpKB?~@DW{OLAz^+=FOmkR|3=a+n^b>tuv*?BBc0#OoM>aGH6_3OO};K5xKazZqn zOTPhpagT9g?yREEbabaKaEJDZK2nZo?<Ka0CoE|*) z4xb!q2Gjcnq-7_)x=dYw)ao#t@?gr1g;xvsKo|D*R-YTbrQGM0o<5{Sh2c%mXRdOJ zmcYdxaYv~L|8$!&<)AsOf3xy#!S_{u7!TU^?&*z0n&UeDrB=?0bEjy9L)RZ%(Xg6b zbr|R2!RaBU-JK1SXU(0vSkPXm{iT<#T^s(;F6`R1#X?DLz}|Vr+_~MH*C#`G9EG{+ z;zv?+-myQY{X)T?9ky*#J#ys8ClO`r+2qGgEDw|WJ^S<-s@3hl-wY6ab@iw&zI3ma z>^?Q<9#qa@$e}^JziNesDJAq8D9f49aIaDhnvktsrr7mEmA#vl^GooylaqG1v%!Y& z-uKdPgaw8pzZq?rrQrv9`blsdXT@NS2KN_B-Am=Rp)*zGyjtK_5DCO4YJ11V586tK8T<#)p%Kt&zd&l*>|NsB? zUKth2C@M3hL^drE%3c*E6heto#>r@jl+i1*a5Ri$YglQJgd!s&DqG6v2*2z7bjIA%gUJ zDIXXmYbwbRK(HGOoYHS#cx>3w!U0s#rqS&AOnXJ$H(BLp+a9hp?jU^lHx^(*k_V$= zxZ=Z1S?OGaLCyw!xgqtj){G<%PUbt)w7u)^a~Cdjryr|UP}!5;l5ZQ54Id&R;2U_b zZ0eKi`~8(v7#R@q3hQF+=;&B5j*pdDnRRK4boR~5ohmivR=&(?HJ7xw-j5D5zz@s+lw+UF_r-Y%f{-rBMq_v9EggqW7Cng8y{1pR zZ=CBg&(PSg-@qSLRo#I5^?A_$DNMeudU{r61G2#=SZ3o~C(U~NA$_0daY&qs+!U9b zNc)8-wAOr%kG@XZS^{a?w1HrY@Gwv!JXD9v)2 zr{x{#Z?!DD$D@9Ht)aEu61?xOLEeTGKGw3u=3Q(Wr65c32Pk|wC#NYL3!}0`XHY9^ zaRalFUO2brck1vh+xPAbrG^|sO%OJ?Im17Ei*6FFXNZwko;=aSuNi;uw-bqjD=6j4 zu?3ybEZLv2jY8j0Ow%S7_%pU$y}>}F+aV7ZKtcV={rlfG&u~-;8$v%S21~ljXhBPa ze)au3^yt-V^S;F2rIcwb37QO`qYPSI{a}BFwT(?W32F#&`)rT{xamlF5}y0g{rhSp zRM${TYO*?+4oegDJ~aRV$N?!$Ru77tTTl(^v79TmubfIw_?^nsCu)#$<&j1wbX>J+RbE(tm7$MfU;VX|#V=jGS{umF75WfOymr%@cGfJVJ5+CA=wu*> zb7;c>ATVUIXP@|($h^~s?K_?Ve9xKZT7Yw|Hz0J!%yUa>GVjB%zDtPr{55rxN+ix# z)c;gn<#BV*Ks5@_*TKVms(d}cU!yc#@taEG)Cx15GMEgU)Tff1lA_YOweyK$jX^70 z$fdBC#96aX6wfzbwoLB64kO}Jl{?3fMo{8N=?~x6H*={>4QIGEl&55j zw!-1tTHGScE3q*0;Ga9Q?$NAb&&}cU7HJJrJYTn@)U*znfzDLcj%29z>ej6xCFhNQ z{;7-ka=GX9C*ww8f@AtR5a)*YENuIob$~YT)C-=mdF~_h2`<@~PGCHAz8xZ3RsrsA z#^{CoO7D?W!rtCOOH2LB12bn|oTs#rK9%bX8uVChMxY7J|5E*qzH9M6e2B9AE=tRf zA3t6KOgE;(jCyH5zqZHKpYVKb=bv`!yQp|caEGhtcRnkH#qCRN1*&RuiA!QS3`{C% zZ#3Vr@?%SsJk3ttzENGxNBhjaQU@eWN37((NDQP}iM zg6p(pz_d9>I9DSMOVBZh-Jl9~k+R9-4sf>HBag^us^osff|~p0|%0LpFb8 zQoJe?Fz?ili)=7Q$Z$(mtT4&4Z?EOk7gJq%YDi?9UZYfwTfyGd69n&V4KU0jzy&V- zSTuO{fRNXK3FcD)5=ox}0{}Jy!00P+!~AAW#O#dUMiyO%K`T4;w6w6OC(Z6tuKA6G zHGM+kI*pB*#54)?mgk2OPeSUbLu*G>;UCbX6syA1WKg7BU&3hB11A+kAk#2sc2D9( z8AA;5`Vzv_VN@nR(Ku_&pGYMst@o6Zmi&+F?GC5Vb5mj5m@4$j%Ca!-DJax?@l%=; zA;;g=pq9Sw;>Cd&EIp(8lF>AnzHeXsU%zMO)7RH_&c6#-(w6{9=i7@T(Rx!)ItOmu zU?mKyru4F|n39VFA)?w10(W+Mae0Hxw5d@PAtEgE5gd{GI8%9BNJuf-tOo3Xy*c7W_xq$_pyihq%54N_}R-VRCKi# zIVSZTHSR`uEK)$~rk;Aja54_@n&;}^hOA(oI=i-@+3X5C9eXYAk@@Rrom zFZ_BDhJzrybHbS3x#w$T)qtAh!5>E}B!?m5v`%{SOg0P1rziIZg>^Ia`^yLu-ZkgU ztlWJ=GZ47+99o3_B35>!BJdP_E|KoQQJpN)tM>)#SotQuv7EhX)M^1mtND-587zK$ zuCr;7R5zOtq+ijob=baZR{*5ux2Yw8%)0%dS5@9M9;B@qim&UG?)#n1a2qbKz<{G; z)Mw6_qeUQmj^y0=^P|d}3`(l`1$w7ke$1FARC@dg%^R<22CH}@{|m-u z2k}_9dtN^QJ7Yn+w^y%Rfht-X>cbc=zv}bLh6o7svdI~@cTiheB!-1;UImu@c9NNi z504eE1z-0BG_^ci`-*3mYfDB((7+AF?>!?=!VU$8epV3H%>ngy#l1n(=PZh{ho^f0 zIO<`V?e^hmTJv$^#>FgN;53E;7s1-5+xa?Aow3B<$um5sRhS*l%*`G5rrbqT4mI4R z?(SK3`||upKTFbyEuI$_Yujj_P!R1ud-ZAqi#?DVlCug&>%y$gavJiduU{=Bl+rs+ z9M{>`&(A!g6D!=X-O*dQwSOB~c)l&LiTjaF(!XF|`y;PdW1l{MzCvbfX6cq|&_O@@m>oRxY$(ch#^C*eC-E|~G}y)1 zh|kCHwbp#qmDiJ*Z_&Pe57R~}4X?S(+qPF#cS0NU%fvf6Um7(97+}r^B-hNdohX_u z!eX{dJDN)LPH`k4^gtYf!If|CH>c*Kruk~K=!cOgj81A9J6*O3y6Sas8qqPydE_Yu zuJ`ov3<8d<#Y=CX*FJi1Q8nq9I$ED)>Ty(vETPeF-L_53G`l@&IW2)Owb2|oiBO{q zM4e}=2l_tWKXpcjO+P0)!*zg9r5+MC`o3Sz?(RA3LTk#YEi|?7OW{RAnV*>_J~zj_ zE!d=5TDopcYR7CJhsg8VFR5gAG7WkIP6XXsHyC39q1%qKkE798TlrUIY?)r%j1(VZ zWN-p7gsjQ%X-euOzyE>pB9a=tKuwJ5q}ohWhH+I$uCu>IlQ0GJZ=w^4Di zd3=+u_&2QOUPiYqC!>)?b~e2}1G>37*S-YYw<{lK)zpI+#mA2`=KS(>nYvlTI1o%D zxrMPoQNr}ojR;CwYmRIqbr>P2<;T@^KMBl64$bmxlHWw5e+$-DYt7=LJi`{jqm>N* z2?ow+LIp`{I)2SOBXc{Z;_{6fJ>JeEvW(Ui*Ryf{Qnj75RG%m14hc8C7_Fy1#J4{X z#%`uTu4M7$P&%z1sfWD=a-)bW&OD01ignuCZF~;0R>Vc!S5pY;1jtA9!aR!F;1vLl zac$B9_1wu?|G*mrhK9wkN~;;gjr;cPV;%$9M7$dE6TM8WL+OnKGX}16#u}ex;@0Dn zUVi&0) zz?uhgHMXN`sRgE5Fu<`J{eh>By6B@Oqdna|IXA*@afk1JhKVYG-Lb_#;%?S!IBiO! zoAEe0ZL&7ckeL!nBPp$Ay0JT~ti1Rs`ZhOL{Q2jfF$m$Fqx;DKmVvHzDH#rj#Q^We zme%Vi|C>{kyG=RkH_6qtyWBi@mRB3|;cV#^9Xj+jwLV59eD(hQDV|Lx)u~(exe}k} zyTT>s>A9sNJtJz}$jlr@25wyOx!ofr zg3*i_hLgFC_RpYIU&nYx%%moBT@J^H-FY!7iR0UHj%J_5ye7b>_iUn-?GLRe8bfSl zW79^n>wWnyj>G23mbs--?x~lF4cJ-Xl>3~}JBW5_oe4zZHegSi`b$wkr9%WfIe427F>c5B$ek-ZK&;P*$`q%G|S zl$PduKxml0Y-VB(vy9B0WJU?-Ya5> zBeyrJa%r?w_|1M-d=>BV>scQha=6>0Dj$#_Xtrp zsMWl~pcb0`0U-f7x4H~IY9tE$P01-(A~Prf2zwy#{V@B89XslQV90s3I*&{PEfIZF zi@lVqMung63)$Wg42L3K^JvNHJTeJU^)^{VVs1Q>I;N3jiBHitVdt9^;gWefEe zS@l?mOXU3W^2Cv(VvPGBIh+37$QNTM-QIuPyRVGd`9s@XN^xc{SW&GR3I=o&6D4Cy zvT*;MwHMy_Q+c@xvq2Uo`ie3deA*1;VwrlCAuc*6RgW&+!9s~VXk1+<3}KC(W!<ejFC_MzTq?tV>KveP?3Bg!ogG0gmSJ?K%yVUB+%Z*z{*O!yOxg8P^ph2c0Qz zSK9jcn0{PD3~hDx)a`e>!@|~4^$;AJXN2#B)^!PEEjYlYOwaoyTrUm!VxA$j1}uBo zfa*|JrtjdNe>#EZRrnnKBq0yM&-@P`y+{fb_I2vjz)`HnH3ea;2Uv|7l~*0y5TGSh zrToY9$a-L)MV~)*)N|^gZa$57tN9^u6RrJn zI_g4;U4LOF`*YU3j9EROcJ6@fHi1!OPbpNTl^b$zUjP5wQP z;@!&Kw|>yHvy1nXsI9BHcK`lPHpU96tze;|osP^Q4T;vqex0C9qq{08i9D5@kBgs& zir~`j4@qBU`R3g?xV?QA)8+dme+pkKf#R%~7C3yI{^S}DV@XP{m0N^h&R$Tb)rYY5!eSV6 zmlN+-Ed`Fr_i(lcBDXJYW1e?&kmcg}3u^9V^0knD#P{KhRLCY;VMk7qVIGSgaF{?z zhJ41Eg_rr9S;qKul82>~)A9_^-faB1=ELbB&S)x~lL(MGYuNCzMEp#Mwlcx}>Qai; ze}OgOEnKNBvIep_p{a}MA91JbMpJ&d;A&WZYDKg}mFVzb`fdAF_A9DWub%6i@dVv&vauPhC|nM#IgBZ=PiDnxFf*oHwr<4#2GF>=Re;Y7!Wd=donA6 zu_FGH1|Ke$rJaohFtZ%LeX`ed;`#eF;R8P9)*3W@!I;@Ot-=hxS5>`!v#Jn%h$SQw zyxoT4WOYhbQOA6m*Q#CHIO%N@nsBF-=Cv?t4=1YWnf~!cRK~7dyT&Stp(fjzzB9KJ zd+NFV>>Z~Z$eiuj`rN#XB~GhOl8Hy?MwT@&t85f%GDn^c-Q^af60E{I4l8P~ll{vY z#67yOya(w*Z&QmJxO&x-~sZTG#)x#g3* z@BWqQx43rWyZaS0KwqC*{uYblo~}^SHZ^TWmw_pGL;Fv8$^I2bt!dv?K%7DL?Ji}&nsxqxKFLzP~ z*CRqkRP z{0fqeIWzzVZpG08esuBCO}Aj7VDKH71}sxW-~% zNA&+1ON%ZeGps}V#dq)AIk@J-C%oL~-yuVWO0+J4vca;08_d6d2D)gz^J(+IvboE0 z#;D9K&d;hHsWTC{Yybnn>-vBC_DxN*v*pRs8^?(Uo{>&QJ0;srow}g746T8)2Td99_Z~vvWh{&09p%?x&>w4w%3S}a+ zF2J6ys{YR#9jnT}mpko771uLFW9e=+QNhDSFzL%#32H>x42O!Wdh5N%Y@V#T{|kQZw4uWKNeV zw^b9Z!Gkr~^$bYZ-j9o%7wd#>-n@PFNs}~|_d)kDor#dL8%IX%@gAyg!5PSvO-EOC z{c$7M=a_e7u1)UK*&|1e4BNF!kLCNbV`RpaKekmBm0%;!ftM^(ueTKXZERHFw2C={ z$XGe>CE7{d3~|!Rc-L08mz5^jBguoAC~Hu!;_2^SYLt9cHMM6z9kmm?;O zXkt;K(3N-S!8?4Z-WGW?*pI`=te+T7Xu)8~@mTDzE*`U5N8Hr6-h{TrnFTssyVfO? zR>kc8YW{K2wFb?atv;%2uFs$y+8ghj6WrP3%pWKsrjxr#VW;*SDNvQT%u7jqA2F(+ zBHVcAbmt*MBa>pE11aViM}NFO>tm_q!LB~PS4DiOcFsAZ8ksw;WMSRy)HaD5F;~3? zJFFP?yVMeAaoAigb?j388u*52?c}hVh4+>Kb<9&KDigb{tsZi7c~_5q#+fszRY^c| z-^1E)i!mJ&tI?ywzZNF9b58kCL)OV=2|ne;eJko_{@6MMPKIY4|7~BYTbmDuTgwDt zj7I)H9&K8#K2Qy*LCfD)wKu3-y_&~zp6>^NrAP#>v&W3fg%UUI)y}`;S=_`3WGKc^BAF)%PM;5r0Qo{-kbzE|CZ!tF6+>-CF$P6>ysbG`r z8nH-a!MdW~s%^}}yasKT`TOaczba>*v-2}tb+nUnu<B@3SLe8y|soqBp1F_0e zIoZ{%_i1?h)`cstD%31_oV*`ZDhmROes`EVc<3~bSo4C4;*`z(dH&#{+$}uHm|!k& zcu8*8w*%hzsi;rO9arEpK0KjRPX?gi?O12m${__cK8Kq4quN1m z?iFN~*@NsAUANi+MN`ynkqDCll&K$|of91MO)<{K#vCOAEK;3NljyRa7jF18Kk)5s zvrk19t*0Fy$a=C}o%!(H!baa_o*Z)U{*Tk8%Xw%weO{vMXLZ0S9FX!+6G}7wdz(m8 zaWqUmXr~i$@3#X~yHB)YA%&M~%spD)ILjy0u)tq&Gbd;E+__z`1jBz9Z-#jjQM|e6 z_`snK>CaEAUA=zY!m|z=Q7fMZ*u8geZ___615qQGIVu}frHOSxg_FDlPm|LsvR%p+ zV;GILv>Cswep0g_S??rZl~{jS(9@#!ynp|m8SG!be*J&u;K9R(SMJ=g^9;Fvq9mPJ zovxvwaj@;9z%hG@GTPJ3?3U3!#bL#?+}ze576AldA?-*oKA@ViuOovJsVbLz_s)z< zPW#y7%dc711#fRD)HX5c*Bu)bCntDpXLmW~1YgCN^itNN3tZ;4Z{PmOajVIYGOVNSS364k+Ey0hgBD#jcDtt9O)98d!fu$uO~vMVlsq%Ro0*ow!MN$>iTv$jW|m{73Q*2jt|BhgqVLU-keaX%WP-*tOh%# z2da=g-u3Ogg#X8205R#s1TnJ+#H(mt-j79L4YN)&83Cu;Kc?2s99UHL^I^q>;Hz%> z8x!vD3Am>$!Q9Nbb?n$Nj8w!wck*3rIB$ag-~Ui9GP93swt$9eq1kx`3p{i8(fzEP z=kMP2#4gb7HTz7`Ib6^PTZ%H~Yh@2)ZhA~SQUk;qR_A(lb|U~QB&9QOqsrUt+rTgK zzATQROJsp$37PMXBDcLyX9vD~_U>I1juYd1htU>a3GYMxn6U4sYuNJLE0!IvuJVM%i__a7KapGp6x z<(}d`*V@cm*DWdOrvA6s1N4XP+*=47caCHh3_kJv;Ua$*lfBj za6A~ywsO-i6>YW-jB3YrVI7Y@Hgf<^X;E@v+W>xd)LYB0{y%SZ7QzB=g<4XZn>nvE zv;a4vwm4Eomfn-+dXqaIIX8TtXliJTPh0J$yoeh;o_FuxKc^XFy76+RX;DPaX)WO; zw${Ar=$yJ{K7#%%IBmC-+QVSBE+70{&m)Hxp31B3Sh^Z;#kJYphpDdq|NF8c<@$s$LoOYZ~SEOcHLDrS$p-^?lxH*F?^ z`Q%t!uOBO>@+3gSNAQ1m7=ds^kBKxSL(2R8yubd&y`>d!!M4CH2nsE?NbokG4(m^j z6BIhT1;zLEoP>a3b{@t{yv}<34miYY^iTpr5197t+n2_hIZ>Y-5HsF-{8)p8kJ!l4=csqeUbV0t%ilkgeZTc*vIL_iw^VRYz@1l(Y}f)Li0c1| z6{%vhcV6uEwX%}zbOY%QHw(WopjHi*r!EHT6k>*+>ot&jv3RICpfSnKh`(HP3*EHS zDt+wDrXRYWI|XFl>lGvJ7FW0r`?b{mxUs{L@tvhtK}2lWx33e1$>sIO_6Dy*+50s< z9bfZvv}c`sPM99?*V%hO20DFdlvyP@2?Q|KTZ})xejP~6MyV=>+Rml?{7Hn%3l2zZ zN_FF)ea!)B&K=zc4H`5a^cxAX?vLg5*Mg1-4BdZ$u+A!FPW?~0W@kD~+46?N8%f;w ziQC4aVE*rz)Gf87(qgOJ0Ly4(>$CQvDMFCu9$~WXm;}Gt1OfxdD5aWE%#CvstjBPO zK%^#nK&tNEy&D@5W}o3*n0q5Vy)Mu^hW{7_@x%*Nx zjp2**CL2C=JEB(EaRcU}`+VrEth6uaHb!H9xJ(c}At@=T_msO`?$kaS5`6v@JNjpr zL{3!xaJ`D^Pc-c=Lh(sdw`V&!X>&c$eVt6N%$PQHYT1XbroaJarEVR?D<8Izg}bI= z?%IC0Vu7w65xcJ*hc0_dNQb@Y-7GjT0Y5N;A8p0_Y+PvbV-*HZ#! z*($*Q#87UQR32je@_f&Mh!4DeGjJdfW@X`Uq%i`rQmJPxZm0iv_44K6O!gR9Ij!g% z2XVMicDcn`F84a_wHN2?FiX~80}MC|l9{gno)k@DBR3!wn!RXIFG?z>?C=q0X5O7w zB~?fFA&q?pYs*`hTft$kkUAr)%R zibDeLQ7kywZE}M8Kb;g&sS2E=8#rcC;+%BPinhW<&q2`D5MLs@yfdYI}Gr^3B&rO~d8akf$Z!R`ajb_Q9S-CvQtBEG#Hc1+2h}x`1R+ zcvr2YxU4t*6s6oa0Owa|+%rvWh?Qpbxqh+OA>PEtLs6dYoJ z%aHJ8((urmcw}pt`uWe`cixs%t$?*^RG~51eIJJ1<`2#`=Wn30)N?B9rHENFN{sHO z(q?zRy=%9PqP&|F9^MKSR9x%2NstUZ+y{!Fh+4e*O^$7uV6u%{W&rA;e{+Japxa{O zjg+$9yy$!djF{)$i53T^pmS)(Aem+8^*OrPgx0~`oBr>2-`qE}BOZe>p3Y4@%Uc{k zvO-aqsCAdgHm-0Ss7tls?wIfa5)a#F7QJgyeRu7VO;v^*z7c}l0F5mVPtSW3i*^HI zvzTC$K*hHpogoNnbN62FZDYrdRb)qnBzVDn}!!T?Gwf4s4m7UgL z*vEA!O6hVn;aByIS93;}^q&sJIB7I@*=go|M2TWEG2$1)F) zvtPNB70#p_8zz>6n6<_9`CeYE7kV->Gjce_heO58Of+)0@^OaA`0Oql*~c6nLE(BK zfINlULnN2G1E`Z~zARu)&{e^(EE^B}I45^!%}Y}p{0^Nj@<*X)ztXL?wgQy*ckTQ0 zL4LCfd20&xU(r*Tf9#th$E@ig5W;WlIELei(V|v)kCZ*C?Jo21SP2|Qn*>ylu)-PdYCMlX+QyWRzy*=91b1^ z$3#$XV$-{plbc&GY8p}xpT?a^+9d{z0j5-JG|-`*00S~aWS^(MnU|-c?8nCpqas^R z8NGg9h^0M(8UfE&07N0@{e=eANq4U!XDl}$K9E*@IJr+%ObXJ31Ez~4ESQ-woasg{ zDSrHcTHu_tWz9rINiR{EN9cK<5fz1TR-nf1J>uxz3AkY}BgbpHGRkr_IlUwhpe z`+`|%s|WY&)yv#6YbW;L0NM=9^k`IfZ=ZIPmSXyqE4AA9TX=h6&T*QR9s^s=6>%n= z_TOtb((qHvs>1@1$D?Sjw0nBqF7EDIkP5JUvnpe^AY1a7*16zf*(*MSimj2;BxRfYYMiU_kF`-0GxW}$x&sxL7~u-$5BHIT$#Q!$(8^1C03t5IdRgk_ zC7X@1&pN2#?FGh0=k$zge)k;xfJFJv(IL~2GqL$N-`qI}fa)hsH^=kL1M9dkK0clu z{uuBc>DOc zmJE5fsF7Ol5L##p=4^O#vrX9PuWP^9Kx^=qBQNkToG{%Axw_s=JYb8<3u57~V&V_R z?sT9g9HI<5Y#qJ4ycE+NqtUj1kTh?L&WfGh`D_8BNh47{QRgf_Mn;w(6CVvpm)&Iz zfhCcKW=@)(jdKw4B!^<_>JTCXpe!E{YU>E$>@xTjqof~m@YN^SIX1`Y{EB0aqme6O zbZM5|HM-dHdoV&BL0}6n6&2Y7Jm}_-I*7vlp{(tf0*^Qfm!UICWgQDQ}>mq+ly4^Zgotz*|0 zv9p%kY#8qazD;}d=+O%wo@&TmZKY{0C&8ru4G6GJa>+S?NKDS+zGKZQqcPagJomJr z((2My33EPGzoGXI`DAb)=}I#@e)HtLCjCuKU(kh7%-H_T z#&Ako>PD)XuZ~jU-K6aD9Z-nVe=%*Gi@P;(jG1L;t(wZAr`%j4qQ{hGVDloD{TRTj zR-+85&T*fhMGa79;T)7EB_*GmE=3(^1i~#Ylh0)aBW?&Op|uF4gG&?-$Wk@{(E>P7 zNC4uoB8)-UI046OBGg394FMDTKs^94H!-{XFqMl2oGTWy@C0lPa7D1?~({hd!d|@Ip&#)_Ly=9muRFSHsUS zx75|))BLaDxY!U+m-fHr%js;aiatNV-Y&N~b`bq+S!F2(jKkQ|V{ z`=RcD?j)?4D=b6+FhQZ)gq=fUHLdrVG1*yUJc)pNVpA6am}pPSR1YRE2%>|cA~+kH zW9KJu3pG%ee}spJGxdED{TG}~T*Je>%T++EYL9-Uj$)0%npn?@1wbk+DIsbCs|T+( z909Y2K#js&+XS)=r6+iS_6ixplm&^HB-YG&tjiqsBDl=Mpt>HUtTI`+a3KiM0G{~; z-Ag?bq{{U4C4u4ey`IbOjv3y%_e^!*!B@mzTYBf9Ull|Pa`cysehd zjeGS6wSyQ?C!xVh{Wjg8Wp(COR@`9ykXYA14 zPS9$qsvD)gH$N2H;*Psg=7UOZ(75qP zxHRias9{~EiX?#M=li}~mq=LJ?$Fm-r^cP9OhGl~HXHK5fRMR_APt%}HOF6XD1qX2 ziP&C@)044|Qfo^IoE)`h-I~e6XP$O#gY(u(UtvKIOcdEbG&z5+=tkVQdx&2^k+~KUG5z>x7n%FiIX02^)iN@gnEhm4px$qS zCy5n$+9#_E1-jTdC-=sM9|;Pu?&|>(S$nvn%jPEi%kO>LjZDJA@Q{bf2G;oH-R~qv z6owzE!sKXo>0+jsban8cK`q*~3*x&0O#dIRS52HP4(qo5A4? zn`y{K<8`=>6(#3Opb7c_C+plw6X1*mZO5#9(D=)iFHh$$>;0}H?)id*OUI-L4VYF4 zR$O6$Zy57|qabV2SiJMIcjx0LPo26Z>;s6f%pI%(Ssj34mLSmGzW7Y^lrhp=|8#%Ww4UGihx)~)&C_{~I*rxa_2ZS^meu^SFv?}cU zgj?usI?kTmxeJD^E=lqh5ELRsY*kN1$^^+JZugF?V^G2sXJe-(YK6j-!5KD;_a?+g z!cfU%iv^X{H8R(+cJ105dbb!4k)F9JIC#B^KuNn#pBALGss&}PixzDriF?K-(wlK6 zVQZ7?JCvZv5vyjAb#Hs>S)LQzuRglZ^tUgQCEw@-tBls18ZCDy$98ILOyCo(Sg?l1m5L z&_|H=y}U&?o^?{pv8j#BhXe>- zu*pv7^wDZL^4h$G?36^H?>MuY&QG7+@{QF6Ls3*zWHG>nO>hF}79PE!!`vwc(O!nM z)%ZInS%sV&eDw5bV+vD2LZNHlZ_KxN^W;E(|Jp3H^^67Lctv^{%)gEVmH6|>$Bkno z;n2V1ljLz2kNu2TU0~jn%r|?@q(^jf_kdQMLf$Xy{f?sv2@>3P#xPLUHGY%$+%7o0-D zHwxF0bwuAL;V%O`!YBBIIF4HM>MfZ1NkccPocOECC_WYrkuNVk{-r+Ga9aK z0Rjjvxa94mrmCu{fcm?D1lvp?*7o4=Z0c0N1g`^B6U<=*$|_%+=8MeaCouKeJQG<> zpG@mqOzUOrQ;W830|DKhq=0-ccwZvHS-;GA3klH^L6^)-Tdm1&Cjb2AW=kx%^x37F z;EV@CIx=lVe0U-I*><*=`A96`6r2l%~sX<}}5U)@V&D|3iJMzk|@hEo8(zVFH$ zGzy@Ms)2O6#QIzRu9i+XAJamAUeEaBhYz73e>$KBttGH|0aiUuThoz@-i)}Zv8OgGe!1_j zs#2WiRZ>|4$4fQBXLfs`%(IBDw75KkLSb&v zD2l=nXben8)Ol!EocH}bq<;PS7{oSk+S1to@G|ezdeY-{Bm_JOX(3~F6uUxo-+sS> zovnI!>>f}{4R#C5U`@3F9bPmkewW1-4wyi`AiU1UsebJ-)Q93O)|1T$(-{n!Dv9T< z4OaHZfP{{NbOT+m|H@CDL{fbEgvnS!;5*H+jh)$Q0tx^4+F^XEZ&iX~AM}v8_?F}f ztr(~bRzM;FYb`(hf~YBOsQ}mFjCYHW2C1X~LE(fuozlD?U^Nyc?`v2^f#Ax?iYQb} zdcOi}){pOy(d4&1VF}A2Yp!WC-qH*fB7qY7`0b-5!SAGtXA|+~YirhC0sc2%b-}@nB?hA8yQ@0tvp1qo22DB+(O}tGf>s&3cTgZ;_|ArgEwOv> zELF&Rz)S@o(R7*JgaJomgGxsaT-XM7qQ@1}Er@lTxd4{zp%N##KSMQ_q`A0TOnvI>Mud1IVXcDvG?n>WXUaKLu#$B651 zHt`*{!vqg=jS8aJMZhnZAaxeeIfO}|xR!$gf&$P9{0Am>Hx4dX4B$J6t%sbop{~T6E7pe7Lo9 z(`S(jCE(j|{J1_G2bn{bR`w|Ns|{fCY|y0^gOm^91%wz0kDvPi8JV_7&^N83FttW1Y~+$<4bDS zswMs&i0aIr&5-ngaBUP{LWu7E?B2s~t~nemrN=ZsK3D$GA!rlV}G?phwf=D&+!+J|WF$8yOkhFHT0p1v9bNxPTJOqf4=;bCIIJ zvzQ5!sa31+#ELU#&eR(_*39RUZfuvq z6XtuZwxu{Yp4rbc7!!K*XfJBFo3Q$Up>vk)-m^!E`q-G2c>44PsEnLfBSnrqNsK#g zIq%G(8F{v2g7)m0Rp4)A*qgDKG-cd@$Jl~RG-?q8Fz3@+T6fxvJo>KZj8?m&qS{kE zl)*+tM{rwcHD1B{w}2-O)lx607TZ=9GH4Hx)}cOq?vEuksbU?#c)_<)tI>(ccbmQO z!Go!y#sp)Z_YpOw@bC!*f2f}{@yY&>u9(vL)geXh%#b?63g&c)GO>tyKfij-?;j4bM zYlg0&jK9z=mA>z^Z~bOI2o3nGgj!8-^U`66|M}Ur?to&?To?QXkzb6odt!EJtXlOW zL1(662TEO>wA~PJL`%)|m*BCY$$}6!*ef*)C*V`enBXjU9WN*2sIb3+E&p?4GqdX%RZ@CU{%ouH|L=QCJ*0o6^}kO;ThDj*KEcTvH{)HSylcH{&X{AQJ3Pj~pu@Np z0}WLBB$~`=-J`qL@PxG+u5B|@F+bRCc~Nuka`%>|N#CaTeVq38^T=bxg&w&xx47ln zSKOTLdP@!-dqy`U&aZ=tN?`bo9Xh79!j{lXOFr9K>rS9R$W;RL+RLQ2jVr7ompq4< zkn27{;r{b;Q0;hbPVG0ha?_~u#QCZ6%3<>f@!M#w<4vr-jWMs^DZA;}Ebafs%CPsy zegpqqp(?MG$SZE~zaRMT?_oV<1%oL3Z!E)>(#HIEn4D*`kZcz>$$v*Uyr%M9;e^w< z+-oTz>_+nUzk9M~W(T_wDPN<-XaDmXCf1hIMb-rdSm|tB=+&DOX-S|b{urdC)lwJn zGmHk^i-vQZoB#}*dD=qdH7Irs1ONp4-t*Jq;DhsyxRi3OIXH>GhpttJ4!8c?tn=$0*NyEXF>Hsyt5M_;z=5hD=2p5n_ z=+O3U9H?CIk*0u4E}rle~m+3*(CqphgR^49isO9 zySur4Nbr9*>(W>@3A6Lx-JNIt&&{SOH_3lbE1M)cGxPhSH6H&NATue8eEiaI(u4^+ zdj8K8ANc7pci#WlF}>Z9TCbPR{Xce3F80^Fq`v>Nt6KlRe;jhKfA5*~hwRz^)*ZiZ ztM=c!TdB$>_@8x$xjOeh?|!n6n&w(PlmB^ryQioiQyi7rr9U2Tt~QxgL*+y*T=wMI64frC!I$r zagBqX*tLF@zkJK)&2?)Ch6k+WFsj*tGVXt1bqanoe`SSF&Z7qp_7h7-7psJY5_vu& zeU(t17%sm2`oo9qX!59Pjyp?035|O*xNQ`Re!VX`j4V)e@$1R3{a`W$Kw7z$kulOX zV#>D)AYso@Oyq#6=O8ZD#8(lueqvV3I!cI6pGUaX+q%EWxrJH?swBT2&78}@o{p5# zNOvzi-o52Pldc{gI-%JE&Hu!yQ@+ceETU?jaTl=jsVNmWnerNI!3L%(_MHWAv3q

^T4<4*3k@KXu*48_FZ>H)U6l!T^8n$iQXV+tfTvA6z2d}@{cuLl)_E)`u zyd!2K--QH3KnL7Jl}d^^18L6vZYsthV@Vp~Rumv`H$9E>qkPO`3#)_xC5nw?C)Bkh z8->R=?bj@1mIa)fgpHa0d=(%t#8B-l{*;!CJA2n zJC*p5vk-!bNDmN}Q`hTTDvK<{XXLSkzkc4&@Fp>-EUr&q^Zi=eFNxL~BM ztJmb#b%`g_>8dW_U_N;g#=5X57vk>=j(4#D$#I+5Rw8KnFYC~(2n%;};Eay=A64|c z+55Gi!QA@#aSEM$l8TX%guHyYUVwRE)EL_4kkjDT8}b{`7nMRCsU^JcoX;O;Or5%) zq!?7n5E8HFe2+9|ld(e$Vo@lZ*Vr9GJ$#((-3rZ628cNB@~vCVC`fJqkVRu?J-K>& z*4YAEh5$OJsCDfgyh>HZYt^^j5Q9!KUd6tMDn2VK%Ul&TLIQ1hUD!TS0%L9XoIWvY zGL@JbBsrqWcKpMl@6n@waem2j+^jqYewhIabIM^@Qsd<4<;%3};CZ^_s9nNp$X%ry zB11_k?vT<`vHj+cs{3xKq*90$m?;jo7jC#Fp~!*WA7OR8am2IF}RvweztXgEg@Uw4%3`h22NiZ644Fj*M~}9lQ1yzor?t?4CxlOjf^l+QZil%Dsl?s$Wws(% zntDQx@pBc4L-Nbt+snz1T4@GnI#MR6Jgy*X*dPLk)y74_+Lfq|G3;u!zb>RPjM;LY&LE#ktIF7JT$zRZbBYDky}U$JA|qg`i4vbR zDHjil!HHlC^q{P6_xRCc;xLD{PuWj@e-?8$mzdug{L94HDmGOQT60R(MUh6dEkX}u z&0NxPzjZ_2q8za1$8O4(vF&%@B|YC#&s%eG3mD3&>xTTvQ~thib^yH4>;;WW$PBV6 zH~;>s5xsv6mA->xI9KI7V%ehRh$fB*cOIe~5)>P|;J~qqEPhe*n@eI~EnK<@;r3Qm z0WY3EKY#~Fysb>G(m|$k{EI58l)=d73kpTfR^Uj7Q3u@QP{C~Kpb=Qk;33B5wyh5l zp8d{l+q!iZk(=be^1mF6*L-+9;P3ZlyXY}PcA(y7lr2qV@n@5;M@Ve?XV(-|&*l=& zTa$h^_IrZSY6JXeE~N)+k@J}lUPM-`RI7v*(XI0Ial1o7hESPtN0g+lx_>o+ih#~aG#zz49KwcQpL5@gqYmDhYD8BwBFnt9b;cbNoDZhC2>QM5^@)4D>*6CA?JF2ScR1PY9H^zK3uwD#j z(9WGZ?M}N+nzbe)yWpJ$0@5JUV?R3%i62UR9`1f{OLRHjX|yrFrR4G+;zc*Cwu*eR zk6^P04P;xk(fU0xTrJZ0_g^2)jDKT3yZ7$hn-pXH?FwgQAd>YA{89BIziVur#+3t* zf~;S+Za1(s2APAqe!AcIAfrA3`quire=dYSogj*+6m9xtF4qU>_u|8c-Z_o$8V-P` ze&?`F`-0>!&gkie_?0!T*K=0ILa)Qo7H3>a3G36JK6QZNu>JCNZ%a!{{)#@PqFaH(GMAF}k)wHy6C6+_ZA7Usfx1|Hq8Q~= zO)0)DU!Jq{#`owKVBtoh{>hJ)W~Ippuh$YEwV+F)ot2aJ_+JOIEz2J+Xyg%3s5>2gXgrsfbo)QwE^EeqeZL zI28!0^XRWOU_oC+E}|)Fhz*-Jx1_asn~?q8qq83!W~_?uO5uWY|Aw(A<+xyZ9j+(x zZOK(7QM*i+$6fxAG{Mg9S!wAe8ZdL{DGXTLRpR7>D}E>6&%avjyJSfpfR1ue4G!pb z%xKuOX*y_t+rSO!zua=n$RifCe-^BtIt&MAKnY3X?CIVi>=J~JigA+HsoRJ05%rh- zr!QmkitrkHDIKrT)JvnF`YDI_z5o?wj;-tCUNmUXnqQAQWxyQI!@uPL-K+!}0h7DW zgQX%aMT00B2L~PA;K9;UT4X`IzNNGX=jUlrt7-eLKP9^(6aH*glDluq%a_g!;%M=% z5{ppJmb~zCD6N?H4^PWcyq7~`+MwKqcvyWoDTaq`f2f4p1)!wvudylM6p2Ka!6`-q zg+T~b(HviS0$^5*<~MK)83Hc_bSPSEB9bU={$sNSqUqZVCryJAQ4AKBFmMo5%67Er zV>lWi2We{!7MAfLLqpiWd+SWVjPQZ{Fy8DtyGZAE#K&X*%duzIV8RgWXUoqEx$ z7qJcxT`xvsLO+pV178BLehkjc?j8}LWlfwEjk~&hps7_pV72o!*x`%bjIvjOPn5~IYkFm5esUR zQq7Mso|C$=ycT`WVFa5?5nPbR!F;I?T|J0cXY<2|>x%kcKR(?;8`crfk%gNz?36Zx zX+ru_?H0W|O0DhRpTk#J65K**+K3R3+y+xTob_G~Xuj=HmA*tj4!-4e>EQ~4S#E>H zOGl>C0&~F&cfTw{ggATY3q9U;yNu4>pw? z>-B-c0kHpx0k6=-Yw~sDgvH1^{Q+)U_57s0bPDq*iTfX$x;304B#7I#ATssr;;)Dj zmwbs3XZ@4Pj;I4*-P3Kg1jSS=>+m7 zv)p`q>$*^3%z=b2rCi^+Ed~GQFJ9E5HcRFuKme+0L9{)uUcLJK*|W8zUhNj%KfjxT z?4^thL2zY|=WSl?Yb%0-Xp~67RJ10yHW4{eq@ZdN{=^u&WC~VIKF)U8LShrsG3`i! zFY#!}tv$g4EPXJdr#Z~uAO#o-F%H1HTAq9+$s7)N9V_@5%MILzo~YQMdk3_s>6xo@ z7t?5NNp#Q}kV+^X?_)qI<>a$d5x$LxxGTFJguxV;YQ9AR=|!&JiZ z+e9kp)K1Ct>!p&lNp|>AlYyHyZWL`wT4?*r^h70{ebNWu`;+5F&|E_S#=ij8xbwx) zJA|16HdM+$=+#rx&_DM6bUho?;^K^nXNrPtVw@m3x%7aSBWk6FP%g#~Y24~TMYhL-$v$gqUYSpM< z?bAx^KeD7&AH#QZm=&d^1AGcZZ05v?t=azUyloO!AFbboZjo8U6%ityn-Ih#^xP+KPmVsM*?8;w9R8e4v^wrDJ*O&n9!Ky=)Wai zp@K-cJ-S%kOup2#{YI4ChI>Y&(R^vf-6iQ1s9D#@;vC-0%(5y91PMD&;7NjG!_)!o z)O}LB{AS=_+xRi<{n3UdQ|_8yyPO5?%Xw+6(-yyKeiUXU?6 zYvQQ|DTUtHK9RJW_{b)W1&YAZC8dyx!7hA1X|7WABqD9V>gxW|Y@zQtiwT>SIHxW; zOMw7UfI(-2aFtVUXWg%+=|zs%<%#NB@&pbX-g6?~GSJG8O)dvh9rC%3nE!f!a-+V@ zw6|gMMzZxcyk8c%hts;`ZX(LVh`9B(xpw`d~UWXI2`cZI-hMt^OcClU{HZ+-n>w{Fki)DJ#i1ej-e(<8f#H*y4 zw>bn6$0BG(yqhDZ=v;YvX3`9_cmz93B~qZ2G8h8~CUpz2s!wbx!?_?MN(9otU0rZX z_A7r;myUNJZ;@wvvXY)9p9Wkq6LH1&U`>%)j)*gOv79w10S_E0B(z5G{)ZlP?a~9S zUtLa|vg_itP}zZ6mQ;bYQA!ZztA~xUZ4JVvl@jQ#7bmKwl2l8EVSCb(UalnmN}UUH zxja(6;h!VNx#Zk9VYvUH2r{CIzMwP)Z{CdLW0uJkl{Rf0{JylI-QWF{$1TEJQK1F` z4gmRspddwev!cX?tE0x&gzv5Oj=QBk{|^*+AQny2t^B^3E^4 z?&3n8L0%atQUrPf9)hnwH8iNv2;0^Y(I_=ex>}0vB&SUBhXB15k?_!SHbokq&B`C$ zY7SPHFlz${w>;Zrm56@}3PNeT%#8&hLArqY5wd`xz?5L7|}* zM6vv$Hk**2Cf4yew{Ys~qN1WAa)up`GH|LVD1Qiq)UgCj#(okRND*#@AIfnVjEyk} zmg|^Kfd2U;?SjAM#apSVZ6btvP8A1?eHpp3^a|OebStB4eJ(G*EC(~?p+>_9joaJv z$=F33RH(<%fh@+4%6Hg_Wraz)D2SL_+U9^3l$e!oL{vRq$Z#{0B@`M(y0miO0dRKF zKyB}Qb5aSLEmh_}PrHnF)`@m?cxk_W31`~3Q8(0kb3f3W~R z-p*Op4M+@_pnEsl$H>{Kp>gjwA_xZ<1;6QWO%5$~<@ zm7}Fv<&25=Npc}fgG&7&uhO9$if$vn_5+Bchz)-XUa>ZSA&7|Ma2;zg6NsJTAMll0 z{&{nfDoEdq%3o-GS`)m20MFjD^(2y6tKTo?9R3^nX#KRsJ_*xr-|W3NDB6*RgGAb8 zdK(VwH4<64$XSE%O9l{XL=3iuz#JJTxWl@{l;lzQ>5Gq33sdJDLzi2xSn_D-!|smj zj~pAlVCFIWrXp==@kTDW18*#bOzzI>izGRHkgysOl|jImX4~4A1GMj?y{;UGf4Q@` z$IdANjpCvZTXq5ikux_uLq=p}X08RSp;W)JPDS|`INun4FBTW5Dm<4|5M(2##Ang0 zW&PfNoSg`+Skcj)(DeBEUYiAu0(2}W)K-u3he+DG*w3u<>!V2~DWHz2^#KyG=2&Qk z)j@!x@v}zClKwTyE$5a9y#w@@FRUfkXwufO4pqG9nAX|UiH*LsEWw4ISXgKX?T#Z- zRq+*f?~UF?$Sy-v;1lBDbyVlD+7B>z@Gt=jH{}fCLtNHTU2cH399`?I2E3|^FQP24W>{j zz@+`uKGipnD$)Yi5$~2u*beSbs&EKsa&mzqVFQe$T>t@;VyYIYqIlvU}38K!)7eS%swYm7G4&a zqS09&N0EVZU2OdRKIv3b@=#XFQ|Vz4jsN-<;gx+Nbz9E;t4&+rMbF%WL5+)$EdMmJdAao(+aNI4!y z0ucaU0fHndH7`;4%7d*mWep6MuBI(jMwyMyxdm|gPNlUMz9YP3a`T%nNwY|)MZyqJ z5XHAUFN1fAI4hQvckV+`0+cRWNYoXfb}98s4-KO(Z@}pSHU)!hZIgsi)BEI(1?1oT zSA9)jcz2N8ZW;2IRe0k|HHOjO5X%e2jA#Jgts#RY&@k|{tV?$yH~j2fdHm5E7*A|+Fzf8dKW(B8XVp7-AAZ;08h|f3 z=WEId>@e^3N|Jgdv8Pk?%(B_Z-T)d!%%9!2bH_T#0~g8rB~+CR2FIWC-s~*e zYBv26QhTE@=3#S{*V_fvyu_`Dik3%?Ge%zV=*7={(%d0%4_1mzjvZT2+35Qv2b-&) z6%{{@O;2*!tkZt7gPe9kITgm8KALHlDBj9A5R#s|c+u@x35zKZW5f#36HW)r5&=7C zEAifqb8dmw%6sfQbuzBa_^jdygJ5J*bN6<}#?|#u5vlU5F&jDB%?ZN;W9tT>Nn`mt zdC+qzHA16nR%uuyjME9lIjC0}nY7)po$~qoxYu79k^Y`#Fw8f6MR79c{Y=RiXWdg` zTKY8reh~Iy{-Hu|8aC%Rh=`BTo<9&<+60fFSarGoKx%^uMQs2tjqyjYqfddEbEL?~ zgG{4zzu0v(h)Bh(%V9_Kra}%yR2f?30Sb~0sq#TD%|(H(9J(>$%aE~Fr7Su$na4U|;I95 zb6t&`r6RTp>-qxBED`})M8YN>Ys`GYk)1rhL9tBE<8nRx+6#vwN8;dOVWZ^Vv_6Xm zyAW5!u0iW&9?F;ghgDzR$l~8iE6Z$g+L;Xm;1HftSViev@zkqoIRXXOh==4wSp~1> zOe!5iFiSHRb&~U(S$Tyh4-E$!91Oq#NWpvIvyJK7D$P3)5(y|PS_zvFoU@+X&@>hOqDBZ!l1|7i z95=I8L=A)GP>d00IW!dMwf`eb>0T6qqvq?cVD0+F*l-H+v384VK4XafFRIQ1uIKi9 zki|6_M zyM1V1$251m+Wo)EO^+8i$=l%mxos0g2?HHe zM?Q*I>*Nx`&PM;Ohba_3Mum6W z`h3TPPiuPV=R6{Xg4^B+8yAc!H!3_QEt6Na;IU>j!l=X+Y5|-_&Tc$euu0{w@hU1& zjn{p+@D{egZxiA}^T@q3DhFmd+|l(7yPZ40cybdwh!3~Q)1Ui1-l{a|&9^?SqpCuLvkep=E2 zd-jOEm2@=@&uwd~a3_l0FZvlOT^1V)I_SJp55XYi?xj-*mk=SL!;OWA+i+ zzTbkbUGjwE`jaeJQ8#ZN+93X0&2iEax}wE6)wzt z-=npE`y0My0ZbTOh;x}pH54FHfaawwUUBiPIw!dt4lU@PMO?DVdtpk zDy2+H)`uh7$_j*Z{ov=9R!&^9uJiMR)wWo01-iPv;v7qZk(H2#h=rlru$J?$zjRf9 zL)?Ro#9rLLe17}&M_3FDp>1)PyS*azu$FaNz>gyPn<1mdutj~vhjOsl&$HSE%XP%2 zue3boM@ZVun_JI>D$e2!Ig3Z`+=Z&c+*x@|&Ibl)fk)=Aiot7=9WGeh|Q?ekG;RWc-q7nN@4!A z9WcO#0<&sz3$t$?&u5f-;H~-R70BLt@g$G@SZO*aTdVpd+m+Uzb7$O%loer(Zj5hS z(dE_(e~824i}>eVcb(5uGVUATrf>E6aKCC>#^?&apRzG#=RUEFxSP^J zLqY*Sqf^)?+XoNKyPqd6T>dd>YjJ!d)fv*gXZXsC!#`Z;J?q=t5Ie3j) zwH+AK_?+}aHYegBjSyD;mC{F6$zh|VlQdqUfVA0 zNM?+c3Jpy$-0F~2;&uY{gFM{qT%8%0Lc1v&bSpSLOU$twdoP^!a3zkGE3K2=@U>iH z(9KsiR(^LiXnGy@j8+@$BvdJki@oLXrXbttO^X~%@?p{&|B@X*QTqVE%y=mM$ zE^9wqjNzoLJ)a@4xTX9~onf#3pAV*doyl@c)Moto49ZH7CH|k^B;7DIg=OXUKensl z!TH(jk1_86BNQSmv|L(&#aXxid`rb)*cTcLmFw8V4kwYV(r+hl1}ct40`P7)J{)qd zAoKqEaLRmgOU=J0>8Ny$WxpkN$EE0&f{p0lIP&wG{YP>q!o`jDusw7X)q%D})TX?q z)`6$BGUc|@5wlzKAUAzAC0-vlzQmn7ckWkM!12qyJL-pq?wga1b^`+foc~7YKU&vZ zVX7OJyRh}vWfx!J-kSD^8Op=6)arb5e_rLhLIb94Q~~lWs2G~y4bF-*#PVEZSHSONUYaXy8m~{yooR3H2@gO# z9Fb2*e?*aI?Ap2WN=nK|w$N6$i=t%$P{(N1r=g~)N%_6eIiD-929`X7b{$dvW6~rT zlIBK(B(=F!|4D9cap^C$t-m~EZgsp(T`Vcb=6g?I(ey*7z&&!k5F=HV>!mMM0v_J{ zlh4Ghu2dh6uq|TRY;QaS_a#)qx;k3mJ;6unjJ{7gS{$}I9QCwzy(cZsQ>E*00g2UB zzD7$_vgse2m~t!fbgIb}^(WiRIx|;Sk!@x(un5l!_VZC>p$A|Xvdh=~&g2#?S`@hS z!a_yf61d7kOf^^r5`Lhd#(Z)~Ib(km#igl&12luAB>(j6*;FhfFFm|CEzQs5u=sR5 zFW=t$$dMxvwjYLmU%AFJaMhk218fJYb(_{wVZGRG zmAccSN>{Z6f#~~wejbIv7S>l1de!#BeK6%C{YJwRVcW%>QEdu)jHWU7hY)%VfPmgE zZ-0en=~}XN+V_HaTvQ5EodfaoV)k2~N?tAA(Nhuo2^*Be#o~-<@I2jTsY#4`K{%3<~!a9;XI7$y?#JRL9C1hbmTsuXl(G zCw`^Ld`QW3A$WUTXp1aoE*^8`^SAZRUOqJcXqoa=?CF-h=|~7ro9lchZB}vRuiRPc z+gg9lJ-uKoTwR5P2V+T_3%wNQd-$;WAi{@JSa)DFH~ zyQHQ82HTceZ7q_J6=L0TkZzi4p)(8$^M^l;#I21<>dcvmiNqEi!na2XluqV$1ew@UI!yulC|((f9}3 zjL9&))5z?Vw0@$zQ!JGfK1j7%*_6> zcizWHu8@$|_^5#rj1w;^g(J)XWvs65OkFm}_D3&_ z@~Yq;jRWDn+3(Men8*Fq9d;u(h9Cx!-X$e$sphlf_lw0*o!$7zD?l;bHH_iJx07A4Px=&x~GAgNY$CRB#jLHs?7wiEQvI~K1DAr~zXpPOx@6Zvenhy2-5{tO;C)>HZ;M--Hp`}id%Tw&yxPFw!suTK|fbeL{z zsQHC)A6z=!Kud=M67bWF@r1&g(-xp#&%hvw!yj33DL`FR*-;^>Rb<$eaE`!nnIdof zyJymXdntOXZc!1CaAwlf`ygy)*yON0G9~`dl(S6X@)iTgJ4fryL-G&cPhV}ejB?^N z*CvFloAV$BtKS2(b4OToo6taQm~2ykW`S^d(fx|7?(JTbvcEsM1Y^O~RjRxUMR&@_udqS=asu}aItMa$G-P+n<q0b6A+O{Ff)D zJ#_`WZaE(?H0pfq6OUFgm*|ir*O)evp`b6HiuvPVY#gPACjm94>1-j2xv-7t)xAif%U%LtG=iRw{`Z zQig|brxvjA>EZ#dr+fqvkhGWi3Nv|BG&{u?_4d@8!wdYrshVX=zbh(|efblLj#+}( zk1Hd|8FBLIKC9SiHYbhx2~&&MEj=#!QzE3`vgzHZQKO$ff9}UGfj+1SA<5u;^FB%t zznd%Oe9|gB#8vm(k+a6@moN7}`M3mM>itjD7r_YXSzc8vTST40Lr3mjM^NmUwqx;o ze)PfFKw9XI7Rst-QSf;%>A6(eNCVGWMXRf*2rUlu$fH_k-AS6)$HvCa)VMdNWZ!Z+ zz%sy*1+sT`7TWEYhwhY#Utw;U42XIg0nQ?$wb_JQWkNYiEEs=lnEnnZx6L23n3r2jeJ2e`f4wYtfQZo%Ll z7i>|sSS(g~X>wVVHR+NEfc>lieHB%{2~YkS3BVG)7d>$d5ZMw5J99I~3F`sAxP%c3H+iRiOF#I)#!X+i z$O+rj7>_hd>NY(DeZz59eC8@V{7y5Th-O&dE$^q6{jF|EsUE%2}&q$jk*9Y zO9PhmDleuO$6)F&c{O{-e7!ZkteB+rs*wdT!K8 zQ_5sd8cXvgG(lTpNW-8G`U;fEwoCT7U{P;J8T_pEmTAnkmuZl^w)j%Oqt#DNn9r3D ziz<;0^+X5Cjq*@-)^uEV6tcC|_3PF(xTO*T?*faUGcjovPe^Frw7^kZj~Bs_ik^bOh7@Ug?3?}nZzWGTqFi;iFjF5nvxO%ob);w@1Tg zDGAS-2qXMq_zENP)hA;{yqdO$xPL$~6Akc_yM8SkMqCl^X?o%3U}~J{yXaSTMt<2v zu6gl1#5W46=ut>v(^164yDikS;QvMQ!p0lEWLxLpt~SlZ!4?i0>xcGMe078{+|TsP z@J&rcT#*--1K($sQ|M_@@g|%w#grLYqMK9SDgtY`;~dheN#_@lp2y629&Z4nUPXWy zvp9@wG}*tkSUYhx-KBEm=0*ELcESlbOzfdt>Qr_}Vpi52n z_SSh98o)P2<70gG>83}Bok;P5=LKE{`F@O3C~RIR*{abI@SKIIx#aNpnf|?HJ`jXx z?h+a8hy{)%H0n;I8noJ7AK6vJlf>J@+Q=BNsTIYV4Cv%)^YyoHKyVYY8-+EySw$3y ziqP)6?*0{R2aCyz9r2n%u|~xvNUAl!CvU)3u^K(x*tne`o)>sEfP_X~Fme$9=Vp+VE8+RqJJtEz1xD6^s=;#~;NADGzNZ}E$Ewu#}WL2wHUBYdI8$8&J~C*GA=JXyd4Z!Xz3v#u~W`=C60S3 z{HSq#6|*7WRjoVA9k9=bT(MtfwS+gdq-r!4v1J9p}`e~0kz z!$z#d;2pvZ4e5G~btjBo($b6egE9^82wJUK%^<=qooHTQv(eyh& z_WfjDG&UcF{UM4;{)LJDsX z;0x4z90OMuT*mP=Xwd2(l8v7r3}0lv!KZ#q#e@!@{>lAWX#2 zO1#!&nFBZ7$*EK44SLwT)uE6{62ogkzw&LAL7s5K_mlcw)o&^s0oCTsbxx&& zcg09d-OdU0}PVC-1pxPNX_SIV@t=&9`d!ep1fzB1G8Azx)lz+CeiqbF%D= z&X#s{0K}RI3z`$5vhw>>q^#6v)#XHhgs*dKkjUp@)U3>;I9`X)H*EH7ZNw^CoxM(Q zYt$7wJtbskmszA^a}fDXCwCdL(!*pP4}Tc>lgxBhjJj^Ae?08#I?Im2!esumXe8ej z;bOQ^vgEwV0_<=9h^%5}X56ML1hUs!2=}O3oF5)>I`t+bBt#r!@E`3Kr+b8)t0jgI z>n~imVAf(NqEs((sGp^!`5u}dj}UTHjVjA}5CAsE$1x45FZYXBva#k(Kc3TP5TuMd z-ZZUbhE6(RdEsTki47!6YL0|FTMQoh1Apw z@ve}eAih?OHy!<*p>$WgU#pYeig9BLi}S5c&3SZpc5-kEm9S*r;g8Dtf4=Ai)8^XK zr&E6DI$30wgH7xNtQ$P=7x`RM1LwvF0qPW%z_U9BB~OYcW;_puUU;uFv^F`m!XVAh z&3Ez`xcI>XYP+Vz@6t(EadkU+^Vo6ep+YXKDp?Mqhs`TKa3M+5w=KA^V zs!)-yg;*`|oc;$(8rPi;17Mn{_`i|PK9e+dpWe1!e>u=S#MXK!% z>mu&5>x7}YlZ$8*i*NkAXXn;Zvo-(F$GF46)UP=t{<2B(stGC4fyO01hVl2*md+S| ze$fmzDE`-tVUO2rmD>Uofuus*PY|n#koYxn(P7J%_ko2t9{Idit7K|%epozrd4iFT zJtsv_n_j`Tm1uC4q^oY}(|g(3tqCaCVMrm7v^)N0YcF`XcQ`FmL0 zQ~Ra*>NWc=c)2Fg_BH;5*XaQ*8_RLSfvPY7sbeGkH7|V4xnU^u%--5|aK;gefQ_5>{4cHU zv_eNm!C3bD_qLo$X5r`%{v>&&>S(W%6@Fs9h*&|=+?YbkT`0;b5qC2Y&~2Yd|J4v?yg*AkDsx%vhcP%xxE5LZ;uHHqVCR((%Nw=k4RrjdLc!4KZ@ z(S_$%q$6*3d3vTU#6(F8E$56saq5&Y2$(oW1HL%0%dVb@4KRZafG?##;GhhgASIi4 zG>1E~E(3GqjmRT)1Lv#eMu1z2MXZRlVJ^m;u?c4~2^hf|@$JPBniAh5AoPlsPQKti z1qFWWkwc}e<`5d_G}W)Ds6aUSu+ri4IRPTnT zul?rX8^cESykorilJ1$N#dAi5*>_ewq2gQDzPrQr*Arhc>rQA$P>`xO*lCQg}doMUGvLx5TibZ^>w ze)c-x=7;LXjU0(D-C}?n)QD-Ox=Y1ESOnVe`IrObPl}|iInOCz<$%=)-+(lSIXyLK z-n6M5O%+KaIgE*V<4XGR+=TrYbJX9z9%j-~>Mi-4-AZic=QKx|o=EoV7to1bKn4W7 zbj0!kI%F6^+4iGS#J1CR#;R=(OQjzyQgJAE(Tfy_wp#86!RBTb6%ZRd<5c5w}?V+*|EaDZRc z0Jk}??OeZp{RFK6a@+zABrdp=*`^zr$oD>@hTL2v0D6nwnnWJ$9p_-dcR_}S8tP>Z_ZjbLW^NzYt zJd%u7l(Rp6(h&KI!9ndClg=$@TeDuvNwaqohZt1Xfb`Pyt8e(JCKik34j;O{jtp@6 z@;XTcwMQM|3V}p*mBam;H*3k-BB?@BS+!%wgXTY0n%DRS%Gs$SpkDe;cX1I5y4%U_RKI<4 zhKUQER0NcnuaOuO>WLqUb810;eonKZ&a)esc`Pz=Q13ixYi^HTUAm0$I-l$aa*hjZrIIBLb^pmw<4QT-$G8TEYAGi7XJAin#fb!kqS))j$udbPMZg z2L>gSZZ_4sN4-$mh1u^62?--dD|~U`(xpqjC%d-dDq2q#Z728JjWV1_s&DBA%npk3 zl0lIP$1=a~9c=PgnpdZ}S6Yz9?#&&X3Vtl|PB5Xa-+sc#B+F??+TGbXhNPq_>{6MW zcBuJ^2VCHL5M}hj%eC z?M@@eW=ACR1m3xN5w}FHlBx5qw6s*#lYss1QBlpBH8r=MpUJKZ+ET)eShCZIlRYiL zql?aPPNq<6-0m>K*EGp`N4@h7*=cF@*#{b%OJF-EzF|pKC}V(cuJFEuQEgj95{_ey zL#{LZ)FNthBCt#YjA&x5NFQ2b8ltKuDTBt|T%@hBuaAt{Kwd7^VYFnrpnd6Nl=kAn zvT>QG$DCjE&JLG4qcmtL)&IT3@^g85ctdsDavvK2f47G3s8l4CA(_x})#;!*cIm?g*g!OJu?Qbt zS1uQ#_uZm>-12h1LDTF|!g=4Qje$9DgW<;o_ZPN&(90F2bIfcn`yD=|Zee6e8c||g zTge1IC1Me|Z7@ZNh=_PeWa(~T;C9v!dlmjRpu7eL&xu{V@u8fOU;Fc0Lkfh@@ zw!~hW$z@jYHWz!{yMI4u^{Wb^K|oMYdyOI%Kcz>Gm_x@-8{p4?;ho2hbrqbP1!0eoXNlflLZy8Y-+T)`SU7>eQ_}on%@P zt{XSBUVNV4VH)&{w>8X^3gZ-|HpUBo{_hhA%!i`C7oMkE+js@0U%VPFmu$TpIJNXXI5 z2<`PC7n3l?m2T!qR9l{~W;7|H^hY3TP)Ai+p5VEFctoXI7hX^}HiYI;N59soQ|BJ` z4Br^`bLE#_MP7P)Ot{6V2Jt2UvL~$OixZs429T&Mo|R-D7Z=BfO*PtX)ddb6YHVCx z`cPZ9u7QfI^m$n?eI7UIET7{}jqbiFSFR}0FCBv4my!}w4hr=`SK4;Kr36kO2apSX z(ZL?=%zZX*uFH3S`=}T1N6%-g?Nc`2AW;5k-3~(m;wrLpcGg*#6LJ*GMH0p}Sm3Ow zyl@5=%|R%3P)s4yQ0md+q+zbxC;f?rINT{rvFb+UtQ7jEQ&Y|h~+>kAJ)V# zGGo-y%WC**toAX24PhLU6H8vdph3BP912|UV1n8~e$&Emsf6nd#r>1tcb|_;r z1t}L5Y1TZjwvw?+D;SX}{kg=c+lKA-jSdB983Jm(?aZAz8yWrVCA>2}zxly1*G-!? z;k;6dMWr{c6CWQBkD)8hhq|sh#+J9<=jH8ZU1StWX7p?|a^iitn)WU(sdHM?{*o2- zlg>-$(A1)6h^B9!j6{Pm^0MhYV$PTR2fC=@p#)XunZgxN#+D)-&khizm=86}={t#k zzFS?*A64oP9Xhn6!w+`!HD>&B4lMgrkS+({wLN>cTvdXHRBS?@CDv8`{G57e12=PX zF|+TKU^d9uvLWDS19kNspI4H^xSa_YGje1i*^0zN;ueunoxYCTR^VJh*AkBwVx2cj zajSJt-_4tYSJ&iZ<%Xvz0kYwLt%Ng{RLktX#{QMGVg&e_R(D*mlqwwih6Uqx7Zj1S z0Ni?II?ZkcWK_~&s*W`&6Lup<`A>U=d&{wMTk74mop~=66~C(L7ur8~b$as>9=^Cy zMc=>IkBC+H&&ZJrIRxIHl)>d^kn*qkmPHg)7#N;kRDyIDn(hiR-PeWj4Be|g{ z?Vol1go#>pB>{WMi8sJF$V+)-@4EwhUPh|e^=I0TOpc+w)9I8A=XTwM zIil+F_;?j1X>C)_)Sydf>eglec5H%30=0n~-5K(wiwn(JQB_Ip1iUHe1Id&Lt+td|@=+LM zHyw7W%*09WtM>MO<=D=D{kum>&cORn7Ocku=(|<|21bTnbVZCNJ2;r4elpXZ;pt3uV*{2tlDyOjffk zrZPdro@$dOM>A$XcO-#WX{(pr%bH0L1*BYZUyoe9ssCmdrrevuBT( zfRMNAS{k96@?xR@sU%#CFZS;iaWs3*d)5^E&a=^8qiw;#-h`ZRSP()Ulyo+~4^tN* zb&XhX*@2+%x9aoS%-zGsEZs*%Xw%NHMZ*Ds3oAu5VJD~Xa1Ut-VRLsc|K8GeQ1 z#@(`?I2A^=;?Fy>JPjV1*^v!5-@JK~&WpAiceRwL_9`!LQeJ+zy?_v0ylU=;egW|S z@OM77IUqolaMU6GGbH#cj~}1Ar?58z05MMP(9ZlcOz9>Uy6%J0!cKZ44Bm{dK&x2x zdbtnVAI8v`HOz{Dl+!lm(t!hQS<(%h{G4BxVO!ECl9`dgv8_)ncy)D%(PEh;0o2>H zdGjV0x*hv#rybFnOd*{v1$6VsNeky?U)%^4MlJeG;NHEPNu>_Z2ZCBN{X`)HKD*&Q zr~J&R(1K5R1M?JRccUN6m^D{m|87|RnDCkYGaVhZyF4TxSJ``fcPMto(gIGXtgx7@ zK-#D~R5~|EzJ9${_IL`Tmb_f!{IeV{&$-Gg3++kE2_PRw#1eU5!U#wt53zG}v^)~F ztZS3MvjZ|W#5woS-&IF^qDUd4?kyx12$n_0x-F1gw5FKjkn%Tq3{|MV=R$0o_Vds2 zwp5RdA2B`toCmL1UxOJBbMayFN1LIugih*8*@+~loisNbInvp^y!g@1{>zRX`Ywi# zyj5Z5A<{`YC%yp5N^8vXUwhM_tRws#cunr#;@7UNQt|UmeQ8v240NK?7V9=K zbaZlZo#D)#CX+G`ox#+09o;}wXTke60QCbh1b<#h@!*-_donG$!I zRv~k1>vKIi#MC9RlpLI{wp(#!I&cBFP!VXjso+$w2W>}dObcCq6urJYP& zu}kkpG=TOxAsSw3;z)wge72r_+et0=nIamL)JU}eAT z{w~I$ZyfV0eC_H_%gL`NU0QMaR^jrJLQUzm#2jVMi7_A-R$DkCv<)U&PmAz4wJ&*HojdJ2c9bzYcY1s+1AJiT4@l@DRtQ z&zf~`K|TB9JE4rpAItcUVBg0p#0Vt#!L z2cqfAWrd@Ha~Dz|KuZ~K`Bc6q9(6>N@e3zJ%)4c>chJHIc$ve`mH{j@kbu|(+7gfo z{zlk#Eu#f5l}KF+SAB*I7dD}3RVC0C84D%NVw!wgw{Krh{a{Z6j$HE8%db*nFg>>P z`$yf?AISpA2##>9Zm=Tsm%_q3YCKLr4o>RV3N}-6$rjxVQwfB|6iqX+o}x zFRec%ZF`;#1ds-s@Xzv>s)+?Dxx8Aw7L*Okfe5kI&=wL}Ef=`VyMbhQmHT|y(ZGw? zeU>p-7#>JSnStKvUF2bYw>}TG*XnOX3iIlmv9V?fh(98afYmh}XU>#j9E~{B^kQCq zCHO&0S`!;P|MF*)_^M*45j+aYM|YRU-G!P%Rfqt*`YDhh~9=<_HOuo(u)V&QyF`C8lPcA z&|Yt?_SIxwL5-(x57I36{EBC;Z{sP}lm3pgF&`&A{}+`N8>A@xRY9)oOMlRg6x#Cy z3C^UKj=5s#b1N5wHHi;i)VMDTxfxb$6WT!HpQA!74_GV1%^*)~OVFAbx7f_Nhy8?C zw8XtfYdH*9=aV^>^*jL!3u;5(7eDB0Zw9z^Ao+4z+M_I!3qxhk2SZX&GnJ6hd-|wC zZy3R@8@S>U%O$t4ctj?CoUnM&qT{q(={Fl6H=3JR`ezOMeYpg34ofsb^9zG7f1*j0M}lDX#6iO)<1FJ$o}t5 z_YN}4%gYn~0@t+8sdEp3-DK=^_i2z!FHbozbJi>eo|R|vSxQTYS0!VY+%0BGrX2zu z&l8l|#p%`Ga@8d|2wU~BmXR7V6_@r=tF8V!Ips2o0@6(91z|nDRP-+s`W}qa#O2(U zw|*2;alIjf^yDZRwH~#-bEgFc|4?>H0lyg!a!oLVSTmw_isV` zvFg)HH9&_saH_vwd+{-N_hMA&`Bk~Jju08Ke$ZK7hE*so(OHI#_AUmd75GSsTRVn0 zFRgthUF+YPBs!EtO&G2<#xALu{V0@E>Pxm-`-bkJUPW*d+6qkllZ)kT|Mb2S*>WdV zS9+K&R$JX4Q;#ku4X=_S#)Kx>aJV;MZBy>)*Yg;^& zUDViRM^KO#hF&)-_hP%(^z5Oay-6*Tga5|sJPXdK_NlF<2V}ROar^$k`f}`NuIUBC z(%9zHsct=2 zhGjL$?O{j%`{v!dUB68dNA96(VGUl6)XOrust9#fCkb?Ua6jV|Jjxw@8z*05u1pFuExvuHBh zotFzKPGG?{4ju(k{SveB$6c9xT=7Go>F>kNAE@awWw1$p3+I39rsy4zT*RO(?CDty z_v7QeckIV@qiAnt=cryicpFk)-&>IgHN8Kj2J4>f#guL@61>FkktKtdvcW)2;riV#MNJ!Giw|D52_t zAR#pOG`tm9%+_BXLvi)JJA)3J=UiERb@h(@w4Bsmb*gZ_@%V95?qKu0osvbLN%_1w z3tx@IH>7Buv>|=&%6M?aZ@pcemXf0C(k*A?=CIZsXBM~npF?wr>LeU4RM^uNO`FOv zwUUR^7R{J7t9_?V@59f*Eeq4y!*=duxm)ORWGX6igQo0Owb)U`blKeXK6yh)x#}7kmFW|rNU^!0)Hc}y6nhV>2AJDn=j{!}ZgZ3WyKvYuTZVj0 zO5L40Kj1Olro>^Bi;IiDmX$0{KfR0R#kEoEQ%+N6b#Slx`i{&R_C~8ebYBUa2OO>+ z`4se`at{TDn&DmM{eibAzia>Y+EH_(&R?1sr{<@Xs4hM^I`yr~|NG8;`e?gvI&fh6 zfuOx>FBWeWFA(}6o+nl_9>TjwHE_yq$;w6XK56C1apTmjuBCkL({|YNbUz;xHQ&-^ zZQkB5`T+J>cq#pRG?%XKo-JGJ7Zhdr7yq65Rh&Vu-#hE@qHLmR)S(9fUVd6vtV+Mr*}HwRcrKnyl~5_J6dr zCgm;bEBh_d*zB`aqhKxQLGMo2RW<8mmN{(1C{w_QQ4tysXPUHfdgeH5@#3O}?>lN~ zMH!c-jH9_!_L(F>-L=1ko~gPs(wo=;mHbgSV^Mjz8^uF+|8`ddBBMT1_GK=^y#iRCkdvY>u-+l_oc z$$5S0U{dG@;^jMh`1(vIld(4WMaCMh9doNT#mvV!mk|};t`=-iPqd$R2zt+`R|7#t z+CBza9scUvttd7SY+5gBE;FaEJnnwdzN5Z~94j`CKHE%2>ph^A4AyA1;sp?90K3X7 zFY#Ax5a#L{g|w)pr&yn!>w`83^IGay1J z+VC}_?<3mRWV(^A`K7a`R4sPX6j=+%mGO%K^AlH$)@jw}$$4SOuciBiR$h7T4-%05 zd@`Ne;lvko;#FxhD7zE^6I0@(L59-RJ zMXrn^g}+^YE2GF5dJnKT8)psI3fdSy`1=8&x}I@@88Jl9YHiZrA%7pmx8$)Fl`E-_ zxl?Pj=$?m1Z`iP*JtIbzq*6?o9rijjluN=y*}ip8t{zpZ+^8^zzxAR?|BcaUwpiyu zS<~EFH+|)v4^wODjS6E6^dFZv?g?`5u8(K(x3D+jZ-~L#@x40RcTH^kKO2`tV76O(Ch2rEjZs{Nu^V@NxJb6(zU59D!J_(8 z=9UPVDC~7~!~Z_n$Ppv{q15yI2-3C|5LhR2+^*5gS4%|^E67B8c&JW zs`2ripQ{FZZ&=g}TGJN%gtQ}#xMSmQ9P9Ew2a$FM4c0ij3&a8{>F2djulELw~xT}zn3NLn}ZXgLd^v~%V>b*6s({#ujich3P#84FPYq4gWV$b?S)7eo4W@ zF~^@nOTn)3nSHSQ4u(u4sv5^BS+%*jIy)E#MzV z2N0tvINFZ+{b_)Sz_fcdCs?J+4XutHrTItCS9OyC#bjbH&HvfO515h&2e8%|6-d1t z?Jpmm^hiHJ>9T<0N2+h?tLvm#3u)T9Kcz9Rfi1nKUP4*n?I9DJ$|MV@pVi-xx`_p| z=SLqO(~|&aLNoCE2$&>NE^>Mw|1Lo(4ao1n9EnBE@d($!zsxUaUi9^U^OWz%clI1w zN=khlwZ`tx$2Qh*gI`0!@69(0=Zk`>81VRfLv)-r>YI1m|H*yTDj7JDxRiopJ?#hR zLtYT~!oeyt$>q7|` zW@7D`bEk!TC^Ere0eT{$Us(@c1eHTY%HSY4?!J7v%mW(XLc+6-I9z71^YiIy!57+`A z>+t=<6VLjAGNTj()( zWCp)Ku~FHa=SX>TyAt}zNZ$K07|RAO%~5;TKl~3UhH0YOpn(!3oM?--U6NDoDhp0> zeATa|+wfzEYq`YLRYWOJ7;T*N7I*B&+-cn|QE~smn#M}B1~!0qisfg9AtSYg_DU0c zkdq!Ch<3LbbHa4HNO+?-@}hsU&^x!`jlT=8T4fLv1^O#BWVS-SrHLPF;3+pW?F}&| zdVhlD>FS*FgCim!AgisjbYVJrC90jDg}1x;sg1Du7l#U)D($oLi!_DP0Lw;az-2m( zaP-?hJ_bI+5%)fPc+QDm5L`QndWS&Sq<<4WTNIxSNcRm}K&%)t-T*c>6*j?&sbl{8 zuUmO#UXJ*0bwl%!2?{lo&>GHlo8&tK zb_rMY8+ZxA_TXaftZ3owp@D!SVH1PFr55lAnL}DR6-Z<@6 zBPGWM;l933q23408c&slv&#yq@y0#B<;AzEyik)Jq}T%HqM!c4vZg77!|TNS9B6>q6N%4l39p++u

  • eah}idnAlmkUAny`?6^c$aXFX)`+9s5xEPSZ5%7X8oZkPmRW>k>&LlL0EcA#G&5$yu~u z0m9O19;io^WCYOH(&9O|8@JchRS~itvv7qZMtK+fRu+$w-dgF)_m5K~WRZSjB(Me$ zoV&w}yP)B!QW@gO=}z)Ja5O0sx|>X;0~}rs6<9(x$sd6(5#KIeT|w=G$^}P_3vP$W z+KZ=TAP`%PT5m1PO*p1CNuPbTY#D(gjos|owS)yj#{)|C@zbY;M;JNwf+a$S9a~B; z`NgNBy@7!mHR2&Wk=#?^ab81?K$5$s@&4+l3~}gg4(*?uWi82Pn=W0tr0tAlI=p+T z`m&@B?J~yF1}ZEX1Lu+cy3EjmDR)5X7)mG_8so8g5(XF)Y7xC;I?cI=+6KwWfV9(N zkeC4R3a5$G?~7rqBTce%jO8JR?wD|9B(p-XPuB0zL&!Y*5c^z>qb9Uf9*du0dSJyk z%tB%2cTG665oXdBH%@uUd97us^fWisIO;BuNpm3zo!G#0SAJEE-o3j8TUtqp!bHGa zA&+5j&t3$HtSKx_atsWEECa7$>*5c2q@@z{=`AFVZD4Xh_(FZNfu8{Z2+r=U8W2bN z*-dBT1^XoB;OjNAKkq>)FiVt(SANYoQQxQ<@nkymdmaH_P!Knr+nSGfM3&hbni3o# zEoe`uGxtauTHH5uZ+I$z8`UAt=)|VocGP#fggb~cmP%_O4AT4_AGK-n^Lh78#EF}< z0^;uZgdbFR)Zw>R|DdbvdHS@0{2XM|==%?s!<3R+Qyw3Km#BZFQ^Xd8rerr8$sTR! zo5?uc3`oA%P|AMt?jcYfg`7^mE+R6rEn8##`t^mMk8Ajsgc&X|HqPoorAf4aUnsvJ z9LLe4*TMd0*#`DYlabCmYQr?$!J|iSDl=q2a5bTC@;Z3kch}bs49#7f>rRk7>@qN* zmlV!mV(%O(h(^K&B~N6{-Z@<6_SM+z2Cfrr@n|`mnL@ovkg<}}V?OOnyVP;Uj2Sfu z;#~I^P5hj<%_POt{%&}oOzfRB7cLe6pTq*OOLZkE9PROc7PcuL6keO`I0pv@OKv&o zdRuqB^2>X+RY_6cspZ~bRe|z@zTJ)@T@L9rSI4^YRz#PiOm?+60?_pvPQJ

    jU?` z^1Cv2SE%o(D-zZE=2Z&a3J#t8Ohx$*+eE9EM%-mx;56(9hRb_?p*Qax8#a2>4vfVm zO3DioU&+Cv?|lUlmoZDQ&tIYQB>9ilzMqjxiv_Xf2wsw@6)|6S#=_-AogrGh9(RBB_#`VUEd6m9Vj zoG!NBAo>zk(-vrVkwO$Erep&IgiSB=PlxJYIe^iipR_N~oV)KITLn?tPv|MAd)J|G z(KgJXhx03aXSZ81wIVzS{WS3AW*)Pd64z3Sd?%Dv=7(bdc(>TX*l?pjGmfdq9>D+mp~C$px@HiJnqn zP>NF4C!4PF`#M)Cx+WlQRRxKp-w{x~{3rsJQ6G1H<`M5J6TSIZKa&}d+*nC6m63~Y zrKALWe34+wh?lPS@87q`IzOC3=^@Ae30so|%e3aAbpdaPL+#b6*w7YU{}Xb^8QhHl zr6Vjy5$oc}t!}Zy9Eo72zb@^d)_KTawI>kuBKiCimMl5-flhTW+)%4%hs!W&;wVOM z(rV}H&EON>4thbn@oQXwB2y#D`m+1^z>}I8UEQArWriHtcC4jHZ~88b-^d$V z@EM_nc#rB73@XAz*Q|lrdtq+SmOsD=M_Ho`pm;R%C+UmEsvFtJO`^XVSdo!n3b9d? zs@Nc!8r4rmwaCXO1}8EqL2cT`u%7)%n#tTHw}X`QhHApp4^(LZ{Z!E~kv(zAO;~>V zb%yxt8zGMt)igF%^+=*jH&=pQ&?6PWk<0)Dp4iqr6d>H>zUbh9WG0`Tw;6(_0<<_9 zI%~!Zv0#{{G5|?vhW(s)mPAr_>^)o0gc3jw$Jh1uzd@tbjAxJXRp;VvJEx^DuT+7F z9Y(=Zt4$rMQ8+ko3nXUsb4CNpPu#2{6(ZDYaoRN$Rby=icn z9r)7K&pQQ!WC$)YHxjJ0J)g$ugn`ZI(RGA+hEmWDDWwd0GJj))i#rFwm^OeOeum!5 z8g)JjW>ZmG(2cn{bvWY5GTHo_YH3M=D9eA!QtjHK|8d^Zv;#cW}jx(n0Id)X27 zLY93?ekEh4OgE1+j;sdXc(U%?q)gN`eATNM+X5SzFCW!w*su{J9EcHQxk(7_EVFjR z_jk>2EJjPxB!5>UbkClaJTM`>;+tX`ntE&``L!rEg-ONR=x7vCx9)@z0s(&^xMT45 z$WQPhMSVc~LKTkY1Seb0Q{fO_DFr&7Y!t9U1fls$NbH73uKL>Jd*$1Gxzb(? zR(D73Av_rg8DENvt(HH%mhqPARyey9fato97FM)2T)u`IUS^C^vRk>W&hNY|H4V=l z_NLB9@*EWi?IVB!zHle%y6X359|dT2np<2#e_a%N9Pb<8u$I{`xP66`NSJgk9(sYE zto$Y)?*w;o@alu3M~!04mj*;Ov9O37`!l0w{nmz(2tYtJk~HF7IT(IjCgARRxTS;z zD}#|I3Afm^P{qcaGGd>6q+Zy_DVk2XXKK5qgAq=@gJ}^+R$|DI`*-EhhiRCTyeiYs z5srdC>ZIWBJ76?Y<%R)&qt6VU&~1~H73}q^w{IUdZn94WOcJ)WNFJGKFiBn}HFdPS z450brQ%;T0W>{v~a!6s|&oMyD}#`=|mwgSB9&(J~Y&Jqk8j233;H` z*6-Ax=8N4KXKfI`wW<4qCww$HfRtTSdR@~mKM-x7W{V<`$HI>l5^}Ju%i-OW26CXU zT?@aiQkTm8OHt8Kg6LPi&-U3bj4cbrV*%F5^5x67Ef47U!`)&2eA{-t4?P&kg+}7` z52_hh;}OQc?v`Kb->=_mvJlOzv;F9aXniB<@zc!eEswvS_LyQB=OWk0>dNwE8-KK7E-gP}$8DvVeKr$K z(LIPaVzBzIzR7AG6hkQSXo4-?xO(-3e@NcENIMYkn!jF^>tDj{8NB?3Il*)YUSN72 zWy|S5^DWOiT&63gSy8^r$8`H>t!se5np3`ffM+G=4A;nnhM%W6|Pf9GLU<=T`^m z$xS6r{mSE$IShJDS`8LlNubSOcP$l9N#BAYGZKjuW)cih9q+bi^TEh#kE5t zx5bBh(gZ+_BAsi6l`N$&ePVA1IhHm<$KR(A6 zs(GsjV}wYP2zTfGly)(P3FV@rf$S8Rvk-be16Zf?_JKJv=gBg|bm*#qR7fyl#kbV* za=i0&(3lNaxJ|#Zrpn}#qtXE{WGu6^IYW{WNr_qoR7&2OdX7ntZ5UM5;dx{Z(Fg@e zA{o=EnH{gd-^TRaTLW*L~>+A%jktduD?4K_KT&Y&T{5Nc>Yr;UjCY zF*axFR@fk;VGA=qH8^Wi4J!dnXgf&C+si1ITu}5;Dc!DGeT3$<0PkeUNTZdMWQs7b z5|JV1&ZIW&-rXGVqKwkU{2)x`SzDM>OW3oFN$ZVf)Dwg(%O#kOv0{KbAwl!R4!=cANc?-6v|s%as)?0g_=e=;(z^3&jm; z1bJfz+omsv(&q5?^BWayyUPU}T?+lL5B8`uWFjdfb$hZ-^kn9cXJzwKESR1-g=yRD-A&*iY>|#XvO_x6Vbo zVs%JEMVJtPY|~XoX$_ZK7#}0Bn@PAcy~qv38)2`x>MRm!6Muu@6FGKRBX;%vl?p_*e}>{Q9>jg?Nx_LW!^UjHKWhK$duv|R!KlTXX|PSvc=BM-@=edKHx@?X;TFR9_zI$ zl+P?~#VFvh89sEKo{u;b;FM&RZ&@f?Q-ra^Sou8;ojve^;`g|~98(&~lh<7SEl-s~ z`SAAQE0V8LOa^}XdDv94J{kPC^hZDFVcltW9+`4>tgHiF!P)nt>T=6;a@BD|*?*O~ zl+0p#Keb??zw2Riak^y@pZ(pnsm7E2gZ7R*>g23vKv65eD`iAxbYpEjJynjVXpne< z(PUBUpbDDAYaZANG1p$ki61;*Nv0x{a}oWx>eRUPf3&@KT+jRa$DiXE$BrZ;qpa*G zqsXdsC{$*Sj55krqEL>JBoT!|QHmtVmXejMq@^;ln-VItfA_1-`FuX#e}4b`Za;r~ zKi|{){eF$-xSrQ|T#v`xhX$&#t=fvh6B&lk=Py4IXw|1z#_ZHs{WyCl-B6J?3Qrv7 zLS{2m_svm`Q*LRy{=l#C!h3>q9Xxq_C};2>I>3b7Gk;nWk`R5b-MA5!(5K*0YtpMU z%%6U3^1da>{uZK-CLk0Ikqjra?pI^}VttFQYH>?#H|yaV6Gq*^UIkAM9wP_Hkt1_L z=M*}yU8E?JTK?`9Hl`l*Zt^Z<3Pik@t#*>HlO7fJ00~fdeTf|e@oaX=v@2%?vAxj2 zn4KD@ANQ5qEX{i3S2$EX#-myfKF|}-kqJEm`oapv7JhUbKW9^E+RO}ybyGuo^Ff42 zRbS359XMJJ3brQEhuXS66SbkJAbc^(*J?sgx9yKMJ#WEakw4ANNbTrnV_OsR-J#f# z4&UDS%_-^Me@Y5448qCIbX5B2v>hvhC~9o$I!vVHCuqi;4{f_Mk>b6PGhiY$gt?0& z=k%vDI|IrgPdN|SG+35EHKf;C05A`G4&rnLC{`>h0G(i2qx^X{-c_i%;w7SdVM@x^ z@*c;T3oZ9>w7XJ68#`5%Yh$x#+pbg=qzB%$S;4NS*M%^(He7G0Xd3w++>~kjvF@OB z{ALBvkgE@h7eXJ_&sNsv=uJN|mvs1cY>EeG52XN$@LoDOWYEzuS~Ud=)}+Z(zexYU5Bk}+Xajy zd6gR4%Io*qv!jhtp`ewDOQAX+L&EW0s!TDYX}zkkV!<(4q!5@5{1$u(F*QrQ^ZQ5I;Ic2jQuC9bZ?;yX zUJZ?g?+BmGPIY_1=KRLQPdsKANjwu))uMkFNkg%(1Q{4=t})c4Vt_MA5#s>8x%pb| zg%qUe$?$%$qR4U|_6$Do4>$|5_?wi&5~B6K+qu7lA%Sx;cydsNtT1R!Y>X4MK^sV! z3ZH^4l_5fxwx_b8nGpOhyJ9)+zAH!CeGj1)px}GZ&)|!j(a6aTJuipY#HBI~E%~;O zPV*f*TwR4bfFJb($^YEVr>UNI#XaAi^S11O?HCE%eSGQI2s^{@tDRbZ2#dIBRuDeC z_<(Nqq?bRA>IWZ_yZEmXD8jlv21D@<_T0$xqj^u}p`_1UnedwxUvxuLrp}r_|M>lg zv!iTNdtc6SA2Ry-Fh#Ar1+4^7Xt2~`b2WJ|`b|SqZF(=q1ki#Ok!qtf8;$|yXgiKI z8{Y_AJLH4fPTdYu#ao*CUiUi3%xc2r(3IX6@H==_+gIyvV(J!LkR>fh^fB%a#7{zl zdr!w=N!~M^T8p<1Rn%xCK5ojTh516^?QkQ}V7IcM>KRE|{JRnIU_g}~VvvNDf6_5S zFZvC^`kMyLI9mA(cR@Qp^jXVekYI46`{?XZbA&N4bXK^l(HvpMpelEOKPG)FPT=^~ z`eC@<-d;q4&)dABT{=m zEBD8@Z)E7V@<_Wl!a>#WW=4|g7C}=MV|RMmgJ!&`kbNsJwDQ90a|E@*{vvNg_$4I7 z90^{c^#=ClOsY-g=pSE3Mu*sK?(*x{2sQOxzJ1y&j(5}O)5fz~U~dDHv>B#xBeYxY zxt^@)x^`M4^Mppz_gv9$rRl8_=z7vEG}~2wyS{Z*_18%^Efm}Tcu{oY`OVSkpZ40R zCPbx2b?&=tbj9qVIco593B9XUh3a=wW6~a!q&d+Vl&ascBT^xO)4e`_=^l~YMQ>AY zor$}o8VqTkqH2)-rIEwN>i4+20HJ_`exbs{O2REyrNE+Y5ky8!YAh! zRN%h?^+IuXIle%8Nk{j(rmt(+w!P>tAZ$c`H5fL^CMM}Eyh5Mdty)adpR_Tm({cNK zeM2){E&b`fTKawW7pcy}Iv8l>ok4Umi1nVN=?)LfM!$$MT#@$xlSC*59qhN4QUd}Q z%4~-wjxPv*g(*wE%Q8Bww>gSC4ki_2^86`fF0Gj9Rrn0r*FWSEhYv!zgSTEpHc23A zBKz1*Mc1JpiLw4VhP&|GNR5?1z8v?16vooEr}mvP0a}8K+gdd!0|w}AwX%$+3$dQ# zalrYa1AIRsRn2ohqRm-rb8Pp(_wUXR`O~CtLMkf`XbQ3&W@IJ0fT=+UBU0KXrH_i| z97@z_()=}~iZ%{u_fN&e#YvC^!TsYvgWW<#qGK2;D%|$Jt2xyh^?KE%65SsdmuDXj zot^SHAt8<&D!k+eu~x=MhMLH2*ox$K&`F4tlgqDWzxJ?UGus6;5;w8>&%RYb%(Wv*hxIM)-o5+MwFY)-<|j^7B+yKvWfv=N7uT-j)`52J z0GiZw485c*;P-lPdBCOA2RB`6^aOII<2qKFrAWV_b3nM>-g`!QF=Taru}Y zPHHk_;$LbO-&ERQYJQ2OSqexGbW_`||a{uBD9k0R?vGLXl7g8wWy4))Z?*qW+=GDi%!;iPRE z0z)@Av7@>YZ&&DF-B1ff(D)YdDPyxAr|7SCKa+jT`uNFhCT?7t9iVi1WSx8Mgl57-QxYxKzO~WdYjL5jQCMZ{>;l-s% zY6<<$0_+am-OX;yyryBTOy3PhkuGK25N(uY+s%m{ZPNGU``pl7%@t^cipTN{(7E3< z?*9GvG^6oOO3N&zFqr$L`p=W?oNl~X`pL170)VLRIg__E+4N^Dv=%1eYfY6b9IwdZ z<~DBF7SD(rOr184^&pGfL`3(RkUvDuaz0+`$z5stFYmuNT%ux=)MQrfJ|CYR$cb7n zS+Z>5%W`R&q`b=QQ%R((n?6b3DUu|wWcnBrmSdf|^oA3h!IecWHEj?Yyks^>NPSnj zKB*@F70`Myxhey_N#v9JLY#Ff8#rxxe?`;V|GfY7Rn12!KzO~oWYsJRwTgt;-0-Nf z-y)~8k}s2vD}-{6fjbde4%NGFFHJ4h4>64Rd4y52wBJ6CtXwlzH6I;cYiAdX%%lBf zwHXiuJU6P@@rf3Dt#5(z*6`Ga6NfVYC;NU*+${taNs5-a?E2)_Lz6f${W9*)X}axp zT<{LR8X$057bjYo28V>GK}3t*k&}C1@2Kd+W$zr(D_-}HA}9r3cS02O3r;ib&!fA_ z!o9cKIZieS#}=wrYG&V^i~vI6>|9?9-gt7h*q23z(E@JQu8H zX7=3-y-GXi>*eajTA9mO<yF>_#kjUIKQ7e8b2&*hF>@%q1gGa%W+wI}j>AD_j9-!27hZ5aic37d z!@NLminT5I#lRb{*VV)QuAf|_@Sq}ECAR9vk&?@%tw~MzI5!$oE3H=hQZ03f`;`E% zz}~Mmbm$D`;-Ox&shL2w*bxTt8_toTniTun+CL++d9vV(DZ;s+#=2z^o zjJWYR@iDUwN=q4Ss|)GgpUWnS%(W4{J#{B$SL>ogNCn(Hk9WunIJBrJ_~0~-^wiyF zQD2>^)PUSYYko;Zx6{O)`x_S(FZ}rNqi7WHT;s9}TI62MZfoFG?l^)ooxEe4+PjQY z^y&Q1)GTryqtV~DOFzKNwLM&}w${!HRtdX7bL}AtfF3?bG z!$LFvfMbm83-3B$_Od&zterzh+26%rBSvO0d2)-Zk7x~2dp+O9vXgH>4yb}X_=ad( zxswff+Dcb)ro zGMH4P304O3{gTSrJKX%=%8~tJ@fpk5KsFm%-?~n%pJNZ%67UNzs22BFjV~ zl;;s!`3?WRA!$$s$uQfm*I;ezS@fI65;93i0L3;GNb7AL{qn(IGudA8kl-M=wFft37z$deOPu?O!x(@SA0|I++nMPckP?%C<|{n9xPn z!5B}#O!1##hb}>X!c-sRa+eJl?^RZ5{d93-a-*zpqoVY=brPyIKYJIHJ@k7#)mm?~ z??s>IoX2%%)!%aUs(>zvd5%D@3rzvYIZJqWM70WT6Mg(lsf2xK_Do( zqWIj|oI}hi3?IyaAOG_n$SDqs(R05+N-+Ulyk-cPiyVRyzjUemw)4{NF1b~DJcIKN668&W% zp3uY_7q*`GnrYZ1)D%F#ln2WOV zi2SPS`(x+DdDSY+o?O&Pv}|7)r&Syu(zR&Bqo~aRCrCDQ#ly4rSn|N4rK>k>X*%Qn zF0m{QnL0<`;&$BwJn(zG@7Hp5SB+F<{2AYdz7@qzE+M~v+0Xh3DJ4p>m7tZSr98G^|&Kz~%1ld0c09m$fC%g_wl34Hlv4`}U#3 z;K9iimtDvVQvK79@2d9IS=}A|;Q)=ecNh0?CF#I6T@yKa;K8$rK2EVe5D{)4J7kcS zvuoS&=oVYsZU|_%^QWUGV!r3IcNsnYXF=FX=gii8N-L#Pjac z?%KXbLZP`zs_Hm}P7ikRgUrm?_PZ^nIdsS=H%k#>?#bcafJ_cIW>-fP#RUkly?R z{zF!h^K1SYOz6E20H>vM*&>#HoFB5s*gJ#hH*U8J1^{hW2Wf%1_BosGUu)b>HXr;m z2Q|9}c6OUx;(8R=^$5s`6PCDXH4W3c7GGwM%Ti0Kh`A5BC4PpIYS0*26wAWG{E^T4&|zG>3n> z?5!1QZ)T8Pd|<>J8x*c5y-ZChUe~C^#hC%J*0JvFFaU+)OLyzLz6EBBUtYWYM;9SM zE=r=bJBb1hw2~CY@oe{<)n3qBl*q%S?d#l3b%`8=u0_X?G3w%7T1gPl5b!+X^8w&W z;kit7;A@wIv8tGBj7L2@&GJC?)Ay-m;1#yvovD8%JF5)-9e_*S=~9Hy4sIkKQCB)t zY(^tP#d*Q@h*cB-epJq1R%#vd<<@~EmNEOwqBuy}lV0aScPMhFuQQM*HuzB;?H_Q% zT5lMj`Pb)0+CYA~Zk~Q(m80+IyXAlvME0`Tbc%XXlEfh_ud%+_?%=^EIm+}k)T#te zGd(e`!7mY3`}Wu1(Gc%F(XM&aH{urTts>xw77fT#f;lhukNL)mRog!I)xEkrvo(87 z5l?3D=uP$3whTFF2zpqsyGfe<0JvvFRh^IbYZnBZ{MAHPqrt586X^wC=oK$q+3nVW zf&njYxt@4nX?`K6Z>aDkLXG)PxVZgum1AawW3UKFF!*Rqz4e)CDR&}fIv7v3g(HI4_Tzn`Y%M8CxB(D>1>tOuHdlZQxrn-zSom33Sl2Y<}yM%BE48G zSuk+y?j#D3Ns2hM1=VCs-94wCW(Ne3xRPs9B&R^xi)+r}Y(;tGK+Nbj=NF8RE=ium z=YJ9PCY_y)P$|Fn_y7E6Hz+3k{C4h>L`^BV<*VGw*`-dF`WCG5zQ8gNq+ZSVCl`GW zNw*CC{Id*vBMJ;?t8NQJ;R^I8`%8&)@shh3!OSC}3IgazQfwwlFm=XZDm4+nBy9;^ zC^!$)xzjW+G8=p}c=G8yM;ecl*RfuY6FbBWBj!o_RAu<%%{)Z!l1IDvNJ`&e!iH{{ z%YwSg_`I1BMW7s2x!s@?9McmnG-=_ zMd}1Kd59l{qA|d&OQ{7+ipgL9C%QjUdt?zLnL|BVr=i>IoJ&o0@~8>b<`h^@x>WZW zK%G6Wkp(91+lyrpgmbGcepTO)FO7C-dx#39kewmk7wYlvShP}BZdfZj+iYUfmh-ka ztW};4U@3{QP13YHO)v;@lJrc&U>}R*iBAK+!xQD~(?QK#o7^j7_=ra0@`^pUsb6lM zOn>y)h9Efg^`R0<53N++?I`)xBgS&xWOcR8b0Ed8b)kBV7%{ung;Fjrwovic7|_2& zfj)?JSHu3CHS2KC87ZA;w>)Qy*t>C?f!%IgyOz5$_!h{v>E-Q$cNu$SITgb*x{DO(%)ZY>sfL(@ z?e~zBE1T(L6nGdx6|na6GB2K5?4XtU=x1XC_UH-UicNyF!EQ-TWpjB+(i2Ztr?4`R zLy~_myd~*aWxr5`uau<<7I$4rkQ<{A)odFT@$}Rmb{0+z zoxN!8yY=d?j$KL-^gg+3`sv+wc)9g`U*gay+60f5P=Df)@N1w9y6cSdqRn;ckbxFA zjYz}`XCC9Z9FCwsb(JTi>lpLFSouq0(VA@o`OwBs%SMgmK5FMd|3*!$phTv?h<2tm zM~`2|j_bXhN^{A_0DM7U6c@ku76*(4ogjwMKgF_RY8)zbT!cADv2l>+p zb~v}SAak#+jiNENtLK$VMq_-QATr5u*hhcq($fPh&*jocMh%fCItX;=&x$8=rU5rT zOo5KvZ8A$6zAtivwp4&p<>{n|s534!^Ti9o>sm$X)+<~(sJ~8`IHc`{3$G7Qp^`-U z`Y(SatT<)~#R=s!No;;YNJ)4WN#pXOlg(FcSMJ`Y&qTYh;XrK{l2_$vCsH%Gc4mlI zL$UM4u~u`CCGeeota#eFzh>L*dEjCTpBsrkrXo@Z@-POd_THmENTh}nns-2`zB!d} zl&yp1$!RYkEh-hSZ7F6hdfY8#U{YpAh5#p#H_o7`<7C;fT=-g>!REvquAt2YS_<^Z zdqBOkQ?iow9t)QxpZp_h^G^==iJy=}uiLg100az?ee>=nKO9*aM;O5?pw59ckk{QYc9||D5{w!AHa0dyeBS(S8WcIJw-eh;io#Vj?Rp24Vo%19jyD+y6*uXH!LN!qUd7J#aB%) zfFN^0A2PHkN58~Hs;<|k2Y;umVR-T95e(omhrXsPv3_(x+k3QLvwHO{vYlK#5;D#c zz_EDR2mgSFdFy-F3V`CRYjZ+IZ7uQ>sa%VY4zi>7^pMszX%Vb2#bI@rWei1>9=J4Z z6T#s3RM%XzCPubpuE?qpe4)`z8FzF30UTEY9?*@HOL70`kZk|U)2k|d-%gTufrgk>n{<4#W;m}L&Fnr^tT z_5)5Zt?+$xCPo?<7}%_==&WYG7|j5wSbujOBI`SA713R!L!kvAwOyWO%cb5~(y;`P zVjQ$rOz=?r^G&3qW%B7uVs$j)98WTJUx9F)U)5G&Wz03W^gN&oB1^T7m|x{&Ah2z_zTa>ktf2O&wJp7S{6HjZ>DoiS%=gP>^jdbt&8i8>hCF1PKc<+IdX3D?g9u9Z9tA={L2OM%~iRKV{>X&DXEPL0c|afx<1zW4Eo4C-;AaC&MAQI z>yB9Oi8L&Twxg|WqaAWm^K(L=RK1!?JCpFUh4i7&j(Z%U;=mwpcZfTsMZ#oL_TwQi zH(=u>7tFu3rQxGjG5hSKh67_t_UHCCn87+VX^~QCjy|D~@`@3sROeSd;8a&f?Rjw5 zGgD_}@qpqFDOYXa_{nGHAZ)vQk;*^pe^ShH?^YDosekys&-5aJS+T&@frU$Y)#M@P za?0l~j*hZQ0LjVI0%qCmR|ErEw&g8wK_^1mDR&RtpY)^GpUI=bwax3@!t)0Fl6r63 zNf;N=Bi}`Tis(Y(FWjY)8MaHuCM$Xbq?Ua+!eH`vzq5<;7;z*fIG%|1OTQDF7?_(8 z8MJuD=@w0z+<{I;onz8xD%n@h$@gsJAgCk#zleFp1G6}K#^VqiB9aQCs@!*3N?PO zBF}f_e7uhZsdcZi94FY-H%|PJogL!qC*OAlHf8dq>x_NiD?~9`JRMMP-&otCHQ!?@ zGb8-gYy#bOYsOj(};T@7EP? z)$t1bURoN*>YIwYpW3msV&cjlR72_7Mkp$387iK%v;He^EUh<}GyP0gw;5P|_ zih7)*z|X23guN93o@ewSC7CwS#+k%no-T8+UFINsD9H_MG2#w@?E`pis7CK&0<>ad zDvCjUD~s5p_uWR+X(mv3=$Uo1|2d2snh}*V9a>uu zOQc9C)_mBiHEn#OIAlM-9OC`mxed5)!6?RN^bLg*1~$)$j|#Hr<3a_?Wahv%pClbD zjSjG>>pU6)QN?b-?yl#PU9fl{-R&JvLm(Kr;^olLj)mfuxdwi!-bB}cprk%E^#G)B zmjiz*pf6GKMAJozoGfjF2^v#rf{O%qR_pXFT+J)RQ>Qa!sWZB!=yf1VLuFl=m2{N7 zFSS8I)!2Lg#sX|iqWRE;Q~LUGu#JmSH?keOJimL8&{n4a#3R<5qI{I9r+SuCL4U?t z3q`|z@N1eU<{7kw#f6_}tlut@31C_2bLQIEHa0jiWt)8Eo+c)iCH$slF^cK^wIZ(&lRz#qi@@-%Snk#WW$ZBSv}KGAZ&D{yJ58I|MTGmznzxMw10~hpX}%n; zZmB0rt2cq3JioD*eTT*WVgY`xq|bXM&qp3j`oC*s{mw+5KYt67Vok6fHm>f(_0pKr z5B3pp-#(naFL|_Dc~_Q5675n+rqirYg=We`<2yMunpu*LD$OuX7sM3ANGzW)tzjjXZ`NKMO+b6$RGb$)rs$Q0E;QSFwZ*OPsSza zq_=6-hVk7ZvT2BGhzELNwwsYErJs0w$sIMG{hk44lyf$R>F)T7o;JZe#| z#uNP`qMjVzR>|XCZsID-SA?EfrwMl26THvF>l-~f^qhQDz{l_UHot+I4PM47ou!K=$vItK{t}n9G zSie_Oi{dl!Pa_OxUz%KdYfYl{T;%m^*6;16_MtX&(K2D$+dschxLk;V5Jlvqy+!ma zBG%O3uG{sFg-n&}g zx&b^}AhBQ8|54V9piw#uq+uHizsGqKwm`s0W?IU7!lAB^pUL_3(IAR^PFsrH0FVzQ z?KQzbcK5R^FQz96zV`uev~4HmQmOid^B0G!O}NqUz{`65$=6AtlU585%h8+}5Sxy6 zrd21@K~1Y+l6FndFv`nHKI~Sa(*j84JD?<5P6_lUe85Az&*+F=dQj&h zZXS_R>=;PAU!W(&WOcc|mF~9beK) zMVvph{2QJ1-yk1bNb6qD$gfXR@IbaP(~UC?UQ&eM|AP#>m-M;>bqR3$cu&$TmVTj* zASKTuvo2_9L&v*Ge3jJMEfCk=0%|JbyU)n#D?&sp?XT>@50}#|hdi1w0NlRF@j>VG z2T@Jh-dov;rzWdz>mp5;*fm3xi16}l zTfbbYWWZ#K9I(u=&{{2O`-rGy0nUby5Gc(FeQ*m;L(~V6M~mF>ZQ=AQHul8y_#MKf zM#y9MpyRQ@-@ksnMII4})HTuH?sq(sx(V_v2~22Tcyu!MQn$hDYARL3?%N8-%0K#Y zyV5N;CcGUzZ#Uc*-|!sY@=mh3`}8fBQV5}nJ|JOA3Cq6*L>M4lq@ua-ciLX|%1}`O zd>fAsf7_+HIH1zX8?m^z@y`NrC~E1zuwc?WFnvZ6oo&;rW5a-tb*Ri~qxdkJG|OEM z&*B=O!n-{6^;;)dxro?;=)u4m?#2>2^SyNtBqb!vJm84H$j0~HmS*vc+~McfH`)Q3 z{gobDz36MZ5Ds!_fsz;j9Mc5a`I?s(U+xo;J@BtqZdjGSjY{sPoAi!-hHJ zm*%D3uu74ckAul>dZ;5tP>9(vg7=XO5Y5_c)texACDITJFYHd_SEgmQJ;#CZD>rDb za981z{`p|?-oZ^U*rXG`;fHqB(n@{%(-CnMQNR?;8Ar;yoSY@k0wSi166~)Jdy5N9 z_q2NnO_G<32kvXrOq1D$kV*hZX~l8niZtOi7rjdtE{wxbD6W}H^7Ps9(E)BHe~VK+ zYN&ZJ)$Lk?NEhK60C}Xv1E+0kjprq~(}M_eU zHA&lBNXlr|?MmYXs@=Ca6&GCS+oT=j2@P4G_c3K_zP=rcHHe<3ozIEII9ahK5f|3s zymuo6=rG{CXKL8YmSTr#pE%Fbq>bm`8AVf;x@>!#HC@R4JfO#p%-+@bR!x;>3@kd* zS9*z3%e{H4DL|ayQxQSpq}lPtxuQ9h2(3r?Nfy*bHj6{aAl$`qZq>TA+53uUDqA?G zTSz~51#ra)n)~8%*g5~Z&4Wo1NRy8UiAdiiK4q|zGWaKHrIeG1{w!(dZ9ng=lc)6J zv+IP?=O#i4i0B-;_D%#!`QSbK;2vGhe?!j`6n=QJe)ny3@|_^eDs%CO?4Y7(v(vw# zm{v3V&s!K5;?X_BvNi-U1ES?gZ_btFnOpI4+rTouB`hpVW&qj}afj)sQWPl*PP<93 z=3IhRjAmv*AJfq>pCF!ws6BtTDHv~=fN*4c!bQ@zt3qvWLYO0c=(0M96h82T+sQV{ zg<(ypy)Y1EdV!Xb;oP-;Axzr(goIy-06)-62m@dC{LzZ_HeC3vkd%qYT1%9RXRndA z`6i`I72JK=WWP7kEBkHcZI0BPVE?+(D@^KSYiz8`u7DL%yP~+X>04a8YimC>8M-R# z$)VZ)LM0-$m0>u)=&>?AP)*_T>rw3}9>APcF*}GARCkXxdr%w%Fqmf>Fj;S3!K0dkTNlJwG(N#ZQZr}j% zqN@U9`vD`0ii07Y)PfHKa0zoYs%XElRmHTiqnmLZS>>>E5ns7;oue2Np;$#5RBr%Q zh86pkc1^D19FkpFSXjI)i=-yON167#WY^=N!fD1KuL+auNT{#GDBxJrow%LG>Q^+& zlQx0p1kw3F1`dp<^vikMRx!qnHew0(qD!C8dGP4boxkfhs{Wxnw8;5@msdU9f0%=$ z_E7e-BUwoFC@g%Y_bz!t6#uVGEnPC7UtVoV>%6=mV+ET1n9J)G%a?zTsNlUSuNGMP z@pQJC7kZ$d@rS~2lupU0h*-QjO@=fM>RVWJppjJCiOGR-oOWiSUC{)hN9GyDRV~|8 zw}CC&noA{#&c0?Hhc*Li`D;!B1O0)-2YAj~401rIgOu{@+O^)L(6)LJTh-{`~1v5j_DxgVw<1+IXBIm+hbX9f``iHQMW zDiS?He8WO~%L=*XSlm;Hn?M!9w@c)~y4b_rhjIlClY{O0*9N?I628?bE2SCbFJ68#QUkp|#5jUk0D@9$5w{oIPY8#^F-Kw)SAbc(p#>5N2^YZpYGn8vPO@1>gK1y!UO=yKLUS z@#=38hX{@T>eXXNWxdAjiLfOlECf=Tx8u(XEgKafUZDFh!Zv4H2-a@LD9B+z-zJRMBdC`%NL)*Hfl{ zqh{WQ)}{MGAVN385?j5Z%FX?^wm}Z%_l6D0BAMa7is&M| zMM|eyz)l@%R+%ufP6f!ZA0HOY;o%D|DTV)MUMWL$y5J=z(n^X{7XeGA?*h% zmw#26+^laCHgq`hFTuIE%lG4}|B)Y2RD8Hq4)&q{3Ux_h!((sNzohzV`x^*SX!*Zt zcRZS6`z{r=Swgu|_ojlv%lMZnS?)bp@5jHAKyF6L0uPMz{YQ^Teqp`o^S@ewieCK$ zQ)e~gAMHfB_k&;m|GjuVRki<%Cusb?Z!#&Vo?DZ<{~cy89lVa6*REZwtl12}@^2eA zqoM!%?AI>*zYn_)17|zbyasqQjUOtXbC}lQwJ6lTpcpmIy zNUqAiSBm07|Gr?|LbGk5+Y&=gO-?H#5s>woYOwO^Bvb1uE%N2pQ|&SwN$j7d+3xR} z05E_@j~;F2#9*1RC4MLT>aWw#3$4h6tu+41&8!uqVnC}EDA2&)Y9plV^KkV4=U-dn zgatW-QO7Y)4Xb=v*7vtIXv(86(XKOfhb zfd84WR8tu7L;;&&-LLfD4~&$v1M9bH>-yixWUB7R|MyAiX6b((RrdRTmrC7+{+~x> z$$taHSND`||Jb7c-jBck-<9+KZCq+i*?*sg`EB3jzc;HhM*s6PVr~9s5-yxqQr&WT zP{{DlsIpoKVXSoQa?Aow-%7A1#NY_{>60sX_wl2UP;gG)MeYyGfgx2?g9+);gx4r!JojX>6|bV z)p62huqSROQOk*;&*h1UKc;WOUz95&-X;Db1Ui7^b^vn(wpP4M^wxn#_wk``8k9Fi zqIvuF?eWy8paDDF@l(Y|Tn3uq4V;{G2!tDvCIL=)n?gK}hW4(U<%+_$0@7S<)F@SX z9y%TXo(U@-W+(-u|&5T+;N!@v@*gr?`>HxBOyYH`Q ze+5UM`iruvDGA`H9W)0DB%geT=)VyXlNqZm$$zsKq9q^zaim9AL(Qi!MuwXRkh}-F zBp-FW;qYxdBYeKN@*)ZRkB_ZS#2LuL3f^?$pQTH{^vPF=W<}!&w$4lhnd_or;Eg+jI7g_H>mcCvm`lVZj94&mw#@%O5gDw8L%{x zdGd96T|?P5r@9Inyjv__*5TvWFF3Y?7cSgP^9#BqYLcNM2o?GX=tUveNV2Orf~#|v z=mMCNFj9--PR>p}9s?ua3GG!&mNX)Ifdobk+Qxq*F~9vLr;5#X zj+3~;yJrf?Kn!>ItRgp~Dd#Aivd?lPe}U`PhSiU%Vh`x}gg|bV#rI`b--FOmrZf{t zn!6J?7#CC5oHY&(u6XGXH6CW25v&>-pKS@0Ow##$mFn4Ean3naH;|qB+vN=bmehL4 zI>AD|1Ns`V<>EyBd=X^_qhSi+=|!t{UHT0E zN)rF_=o}l}m;gGf#}R?n-GTz^c37;*aDES10R^we(Z4Fr{dIl6O8}be`fOpg-p^OD zMV~)Qj0Y1|x8cK8@LfA?8Spo|?YCvsjsd#-?klxwlvg_9nvR2vlhp65;vTEZzAvJr zCoR9jX)>Z63>W?zR|67goM5Nm=jLwQZU3wQz0FDrW%%jEiLI-Sl7p-+%u$$16=>1h zOL52zL_`DKKukPgNZG{VQ*6vfgfKpYh`aR+Gyp|vkS;K#uG`G%HCPC5N-urL)+lav zPj=tE+nl*r{MGL#d}`GDBQQno{OUDrNWP;sUNHl*zs>R|W#c}QR%AqbmMi%H92|9^ zs1_o%+en-8(dLkB#3LppMJgGXdefTN2Td7M0Cp#co_xwJKcH&E3DleSY_r*xcD+YX z3j{~ML6~{WY2D=UWpR5(~3Rq^XrZ&ubs^pkSU&AD0@wXG$Cok@`gT50f0NbST! z_9k`2*O9a8BlI?8O9G`asw7l{kaJT{E4G3`yqQwmSM7P^k$M0_yT1C}*1AhVck$Ku zJ`{rI^hXzv+Hk2D!OUwfpT%{P`%9F$#0?91@K+8c%4**wN2q$$@7e!43E#Qx5wkFh zTo`9I+FVChw=aR2IrIeQNF2E9X@|;KlNL88{~sheOyw^i_no!N-QA2y=~13~ay{yt zi*S0}NlPxva-&%~Xkx;fXh(A%oR=d#kZ6rXy_!2q|BTJ^vzbG*U7%8sVJ(ZTAX5J{ zX>$bWC{yrgXJG>VrdRVDeNT@dbobe^WsBLopQnn7VVY3hY@7TztPCHr{iW>C)vjk9 zDKhow+I7rqFi!gq9_(43%Frs!oA+kP+MY0z$%4vp)>R$Hj0+#%MmAb(xs zkOJe6?Jj4}<#(SVU26&lY_spj=VWx@*!&cK6ee_II*LeK~xUHa{{mnl#CrDD*U0fl$k0$^zpWRgL&;9AbEX zK@JB?54=U1Brf#%*+Cv2shnR$Gc#R}6xs@@e$d}9mj$KAB_0isa?BM z*yqq66{Yak&!E=mO*>a(-3c^UA#D@R`cpaSIQLl>PNrb0eUHfNZujq1>sJ^xltfEn zE07_5z9r^8J8^eN;?TgI#~8|UP>qedoKr@~*IT9*o4DQo)N&F%C%EG2=Mk48BWtHa zROd=UC-$IYilc7Q-*+8d%O(Bzd{aF*wa*FsEE3+=%h>~I6f7Cpw1Y#|bStZ^%zc9v+i^51eq5*UzIa#m|mh z*ns13C-CH^oV;eAhBK}D7GL}QW#8^!EP%a>WkF2YNS2d6=|(9Z3IWgM58_)=vD*SJ z#k4%|0t~wM!0ddLQY04OOle=}*}hWFx(x=0l;mATO&(P;RpKK+4&I{k$kQ*@P!gJ5 zJVofcDPBwkRHjd_KkCaqA3ml(=zX~Lcl|S} z6Z)dDB>-+}DyqQB?vhEGHexb-heVr=D00hkb?DTelHI(wr0wa}5cyl)f<9rgkNQMf z(wU#;8Ppc(hT8xx-oAVHH}&;*BpR8Xi0n}#M~ZlIn&)`^7jaU=^mjc*ALayzTi(M% zphv53EGGx-K{`w54h$p8 zs3@PHWGk?(hnsUJnV7_JS$Ce4&6RmN{%Kel4wc$~0Z&iI_f*q_f2|3~00H!u9@_Q2 zw|gzbF+RQ#V6{QzMJnwDPQeaA{<#^a^b&6?@LLPnobZL-rYg*E^0Je?KOYt_(sPs=F-F2EB(fXQOmFJ>r+)W4cqK{v zuyS;5kF27Xeu>uSO(YvY=lo8sblSRBxdF_wkpI=I-%zE{?&5OivQO8p zFO&t_N~_W7^8II?$XvqP708rJjXlZp$jo9^yH=)OO&ug_8`V=K9B;!B*(ut>dDQ;X zVg2sn&L`3JKxIUk;6?Mrnl`Olx57OF6`8&?#wHWMXlERYp)~zzgJPQL zLLJitulg1Tve%I#pb+}HZ-u`?#FEaO>)HPZ+-L%m;C#G?UEgkfZByA`&|*UR7}WgO zqqn$4O9283cJrg|9|opUaXWBUPKO}C`gbg6AQe7XN~vq5b5CGD?0?~B8|N8-S&X<@ zUds`SZP;$zO7^YFh_&4zE89geujxY(M7gvqy#=5ErRS%V7ht28rk++5?wF6iXnqg` zfl#U5yQkl+z5au+GEldTUzQdVFF~anN0)_U@w>HmGa4ZIK%31OBZ%^ZacxP!=AH@r z3OG_F#ZFdww`brP)&!n5B_Oh+3R7IOP8s`dr2>&(zKzgWnS*8rfPsnT#ehMW<6_>z1@ZY+|<(CPnQJxkCbYNCgXLc}o z^XafL+ccZTww_YAsckbiSO=6xo{C1%Pipjm8^xragVA;|u$(k8VFrCXH;s{oRAHMR2SF|6c70q4JS;iVP1Y>r- z-7f!w6y*73OzI)PtKQKxATUhHJV5Sne&f%T48LFxsC|!p)bbJwRhk!u ze>FbYhSKZ?Tz4GiCP}8%*_cTQ1-U+%_C9bQq9-umbx{T-tP^1yJKp7_N<#A^)W2_A zgWIA)4Ic2_`1-cjH~t(@-DTs?_Wsj-;LAQ$3T?`GtjE z_JQ)7E;YY)W^MS;JjFCRo89j2>3LymOAP?rI&jracq2By`#fwkv3x5?=o~hC143+SFaOV{NSL#f3rz&|I zmo)0!%R<&}8eE=XJ&RXrFL~N&-IrIH)8jPAn502taSpfgslB&S4`uhvA6@-sGyl_} zoeUpvyaxW=l2r34?TePy>uN_;?kv}{rQ8$})V+ZHF_RIWyZ%`0hHrr&(*)S+>$(4| z3}w%r_uU8B_z$J3T}9^j>X||B4aNH-5;l(idJPP37`4;qSu;E1!eBLCK>L-~TR}m= zGUwydr}T*SDcVYx2CM2(S=k**YY2kCyU-U0*6JLp8Wz_clNywW96wNr^XvU^Z(U`Je zn#<%mLZ6O3xx6H?bBgKy%gCQ;uY8r^bF6Y64env_n=rYhUv3on@m0WjQ&|~jGkSC% zK|*jEkneR6H8KAx?>!!wTGP69f5f73mvNtjd(b%kT9C}D*M!bVSWy<88;5-X{~39u z>nFni=kicVvLEq$rwP|s5{Ym51>QyeiJhI-n*8*{4*|q>hPd=|Ig>6q&ve5D3xz8A1wklp?{)3)gDX$RTrFpwQ_0@+y)wDPn*%M30hqDQ z!$=TObJG+Dta$?IYL>nCG+t;TzU|tCQ9TqC#!NhF@wpNQIFxuR#OhTN8H3ZHE)MgI zyNWt4cj@^i)+JV^9$S9bUTLqi^pw1^OZ4?zGXW1WnKOKqj%)nUQ5$FA=BJamFrM&d z8vlcpS9jm4(CCPk&3^T41uy|&BzXO}uH*+;R;9r4{5xNeT!5{Ut{aFbIslRx7}InmrRS2G*u;?HBF zT^J*+j+zU%gH%|i({M{Ov5uuQDIM|j>M4&-cUz}*n~fh5h_71Ft8k7gs9I7+YwrbK z)XikDI&Us9ozSo6{_6bv_XiI(7M4$ps!6=j0bnmCwEHwOP_Arqi4outig-_~dCSaGOzm~WB$d1Z?b^W5%xjmt~_U=t*t$aEB z`ZaFFaq*Qr-zoBJ;j+7Q4E^v8Xw9Dd0^D0(Z(q(6?aM1)zbN0adzUmh0ORrXtXDVa zr`qb)fF}3MRap~8YTn{T10#!4+NC*k2CJ&iB??NV_*z9?EhNY~gUnP(qDys(y@&lv zxJL!*m!dg17+{^JV}_CU@az@ynm2j7805?PSI?ukhM72j@ejPEABUJjn@a+X$$n6= zwd33vr|Imibe0tB4Hq_Q_?Me+XAmH*eKq3{(PXkrk4{4cg)QmB!!ycgjI(`nGbTv+ zBp5^!8FIV@8=Z+E67;k2gY#7KzJpiHC2kn1s%|bq1-yA4Cv`~B$xNK!G-67&&Lewo zV@H4`rDXP55d)+U*?2mDn1>U zqL&2c&JrEN5WtpCGTUL{R;=~9a=m&*KJ)GAmt~K>5s79VQdu7&sT=fa>s5?wD)M%kwA}W z=V_y4Z>%|RYnaa3k8U)u+Ugj(GI~!nU!P_`DqN!lrt=49=)pbFy)8hJCroX)de08N?UcX!uXAG0LgY~e&j+@!~o&^eB z|CAnUt1CT>Cv%keJ7%|!Z`(lT%+tT|!ILaOTX+&ftQ3s49kO(K=+$P@F8trO*ILBv zFCiZ+$LKJIf@41jXyRc{ua)En%h?H5j3D-4GPrI{RrT@yn}5unP(8;a6&8Z{0Sdn> zC>ZTI;C^goxJ580sk}7HJDK(;`CUoj%XHJ%msZ)ML?cB7uibkf*a(%WWY&&oTDhv{ zxDi=`r5Qb8UCPZQg0WPR5feNeyWS{cN2+dI#d|@~DNM5rSbw^4h!Y%MReaE&?7Vj* zi-j=h0T-*lLCSex*oR7vIpeIVE11jU1j#Z#O-(;H3mXxfA+qI8BrcecOVrIrtXwNC z+n%4FuL)zKb@Cg2rw=!kTmj1_WHZuI<#-;<{~_r>dV$%@Q}G=NBD&|adXHCRUN8Wh zBKwd8!Wob*y!-&NP0xX~aRYCTwhdmq738=jo-v zk&{BuU>uLr@pbqqCEfV=_#1&6Rv|Iq+4p$m=ZlG#h~3x`ix@!9zTUA8g?=Q&-??kL zVDlJ~MNFCJ%1zy?E)VC_)01WRueoM2vav2M6Jlx(R-YNQI{)ddT{o+K%t<(= zo6bVLj=LnY$+bU;<$ZZI;=`9j5{IY20q`!8?PsItD6}6RIpl1%_rjc(dcVeylpm^k zR~;$jFNb$}^1AuQ)@Nz*7uiIz6nVc)EXOHqp)V#Imva4+@shOBaw=JxqB(LQZ|pa0 zm4BX!f|({jt@y%U`;HQr>e`0Ik|WT`&M~t!+}pYEbl^*HoVMr>uq$2=5o&>VlJ^Bn z8KS;H1c&U!&Cz7wMn2Mrdn)UWkbXW>SNQryy6NtD_!8%8Dq?Gf@OYMVI8SK%nx{NP zmhcc0sobtz4mY9Bl;?p|~J-w{y#|k(yLhvnE5k0Y|>k{rXc{swV^8-)HNui`B@N^V8 zKa$Dou^bOz_Lh;;2(j954;VI{;c$6K0rsojwxT^kEF?476994wYYkV|jhfd|);b12 zL1BZpIebyqE*eiH&-?Ch;7b5_X1Gp4ykDeCjWgT>ckMm2{xmXLop#uT?hDnFElV)w zFD@?GD_^1GM7sFInB3>knOoY~0iJH&tCEhtxu&c8R!v_D9BdO0bvm+m8M7o@Tz za!xNE8=6VPflkbFoR)rLWL<4vm?M3-oC&;6YNIZj=Ie~}i>9-RNls-oUdH66XntAB z`=0AF_>f9*$l19!p0}O5USSwf`%x=Ptzc1@S_c+WzD8lv)AyIjZalQ%?|W2HP%YuK z6pt>prus(RDU^Iqqp4!DEoFWC`@bnDoO{nRV0mmgCV7S@_=TM=Hn?~|`C7;hbN?v^ z`^|oTW<(lC@`!BTS1YWtY;?!h%Zwx5t+Uf37cvOd{)S694rypvXC~c5$qi5}aaonJEW)3UQBk>R93HXCC~u z_I}k1*2dhFv3g1L6NG>Xz#~#nc&o{B^oMdGuQ7M5C?kE?|2mL5^d9=sEC9@KKEX-l2CHRke$*pcUm7b~fTn*2NLV9}JU?UgRLkUL*4ITU* ze^AKlaYcRRT=l-V86Po;<0Ob}!JDqRGG&19TIt|3GiFOF91oqO+>l}a``$S>9Kz8J z!8tr5JQaYRQ@EAgox_c`bmYjKL)2_sjz*G=PHTnT8J3?nv0`xPAM+((8vb-yb7Y;7 z6*BW_;!EezMfZ0k@7KZrK>zvF-ea0cfZl5#t7-y5ZPj|Wn&bL>I;`3`*@`MhZ7ie` zy7oh-Z&VcOTo|);l$ZNkMrUyS3e32ifhn41pYGQh?sRRt-&X6&`TlbEv0*D-SZ^Md z;X&^4dS2tuwael!k)zEe9nT6`N<8;`iq2vst1J%+@H1DK{C@4W`05kBx?@%I$uIM* z2jMxJY4isR5tXN+Ar>y}r={d zLDTL--gDqRUmu^Nv+6$i%5}BJ(#S7>kY8XJ`2}wN>OTLx>WYBXPds>RKhJ%E+!^kZ z>1??@2?pUFjzc7}X9coZwO4kYxus2JrCskKl0Lu#$g=lPx?XCr0V?YiZ*IPp|5_J4 z$)U>}+|6b2l^h_Y%uL-9SG>|r_hnXe*CESq28?E*5JAk>ah(ZL#XyVz1GKjI@YS*m zraTB!N@UTNnrzInBRcgaTNhs*baDC3t0V9q@vRr}J1Xn>tV)JO_627Y8Rj8mQ!T0Tv)fu#2?De zQDtMNT=@Bmu}1 zf&5(N&V;z9)Sz~Tt#4XJR8~r66eqaf)i*^ioHmJhr3ZyQ;>XFFN4_n@s{7;Tn(s7L zj@G-SLAT^_0z%f2Hh74n+TCPy@MXldE1fa#6=VCvcD_M(fv9XMHf!~!K!z|maERG{bT&deX6QD_7CiT zk0{5#H<@6hob%cP`*^+DIqSJy;^9UL3ac9)*?sKd-BE|aQ!z99hG#&RmtyVEC?=?Pk){hvl=crPBjp^+$^zn16kIq@mxW)eP_y_oAZC^Qudwq_{x>ohd zE-T^Pt0M~v92>oDkH>ZzzuuM|tKmE<{6Hw}R?o;ZkMb*=Ove3UNwG0_^1-UzX0IIm zlP=yF{u||9nycI?->%JmVnaOt@X@gbBQ?5pyZy-OlIDS4nX~2(xLy-I>F7ptBy)I>yVq(#awxM>KxO)MNx@ez{lp)4TQ0;*iz=pLTl z@m%rZLD>llo@~tGI^pGdTaFOl1T|6WeVjdHOV;Y4^QiuH4^Du37Efhk537m5*X7~Q zm#T~(H_mv%*E19#%F4>fk#(FCzm2K$SlOkRgW>tQM!#L?dv)F_FdE+K6I!HbuKd6P zN&A^+^XW(NfHkYE+;?fxG8%zI)gs%g9vG)(c?eWlMqXG2tnX8Mv-<(jaB!fP0%AL zxDlTH_)wz#cI!;P$1K;I;I;ZTy*KjVbGucn`hSeVA1!(wd~&`t`la9F1Iv#+So~z8 zCARW1hICYUZczI^8%QKwmMk8%+=tR54O6*_ds#KyPQ>IWTQKBTNXeKI&%3>)Y znp0E=O-PD}@_f#$bwBrSd$)Ic-#?%2`EJ{NuQYtW*L9uec?|opANxVFp}k^(mGWnz zq3u<{hvLDmw3U_5C~Cug)0&|g^>q7AaTNhxPujdkz1Rs@-90Ch(58*aSl?#-zCxi} zo_748|?mbx*vQCR$fBtzM&H#ztm}MCU zE+w^mmAE8y!1#L>!Jj%?Zw39^jz`zrXi?ixTaW(enr|N_m`2dz`5-{`a&^DVt4_Mx zJgUEaYm}ZQ-$xBZ+#GI=WY?~Z>ojloe*XPr2hSdZ21OGBF37d*_#{*B;yFav(Jd1C z{v%ovKF@jO8dMNz6tVjCmTMXmQDV$<*tpd-<91H&+b_xGQQuL9nuvV?jG5A`q>orU z7d9XOuTzRvS{YvFs#E(z4R)RZPj(^rln^-l`OdNtTP#z_JhY!=S+zYo(-UG-zWdTB zA#uAD{eoDq%>^eo)y|9m+OoN!L2$r8ibV#cZOD;3+4?>3G zT6e);4aeLbUvF)aFx0}!9)CJr;{>f)v104c_FAAWb{bdW$CQ2f$>f-s%}a zx8fl6t|_`~|Nh7ax`PVB4eoM^tg_GVt?n$MM`rs@9z1L=ri&gee-WlNoc{f400J|n zH=kzqBCcJ_5mYn^U-@3rbrBY3!RF7?#n9%MlptDwuVRdeCI&3FOr7s7o7_~kFe|?Pz%{6b~n26rf%FU|AlrC5Oedt;Y zkR(qmrMgML$dpzb>Ff=i*-&3^0HeuH(ec!y)Z;Pet)U=yY*(skvBk~}TUJ@GlOAi9 zU*qeD>vMiz*gw(BE<&8_#$TV zUPiRwJR~43EiK`v%*4_3Ha}ex>G-{&o)Q%VnCk(p@2DbO@2&Rz`}fhX5f>=8zw(gq zOM2kIfX3Q%n+|*4RTTO~LJc#9=w_$rEC8DeAt9fcu~FK%Rj-fwO()N#SHC}%>@y~u z4JJOG>b=RO{>D=1VYZr#P_<{}reEiWi$MOtg3MkQJXzM~e89KE`| zr-%fF7?rPr!K;fi)gkME{hVq1PTI=qepYQ6JY-1FGs&=PhNLa##Vrrehym^*MTiin zJPzFhhiqCpZ0Lju*}kJ6AaY!QQUodI=ljFGPJ6!S?KPu0yt=c5o#>j&>v8J>ziu+# za7(R=*vyGi3(S7Jw|8ByX3PBpZDLCDWGc6p`Siu=8eicV$|}+?*Iv^Q$PU~=d8mZZ_ z*Sk;S$c+hZ6qhG`wC%g2`tqya@|MMHefKn{s&dC1HnQF58Q1k_)-^(|7tG^vmx|0T z{qv$Q3Wb*~cD~sA`|@I?d&ereY)S?eERWARv1}SuMsm6uclmRh-YnSu?8!|6K&DRG zP-(<$R{>d*0UAT>@a6{stl0jjDBNIXn1ceB= zZrbPk#e{H1jGG#6_jXzQhG5Pl*{s3B{tXK=&zJ!4<#*^lZ6%=Jve%CTmr+8VH27`^ z;q+(9Pn`Hj4L8Xfc1;K9yPWq<_eHv!>cXV;3MSE~>&Ho*J0-tJrb7KS45dp?eJhyo zr=6H>s(iC$8sXR+9d!i=4n_p&^S{_kem*Cues@`y3`em0C}_)I3pBZe8C_ea2Sjzh zOU3Y%vzv592uUXfCEqtT8dEF%o^r;k%uuuG$&u*UK}4 zZJ=lt51J)o43V8b>X#*YT7AbZn_PVPJzS#Vnce^R!-QBx)v{fz21l7g7UF21w)Sq# zDd2z&!1-K@E+A&i_)a?cXj?C_Rd2l`HGazKfKzB}0~G|9wreMsISm-PPthadHdq{r&g0QFgwI86#HWpbw@2vx+1E^fV!bDrp0no+R&XRar`LTM`6sEN84^QjfI_b&vc(Jyk=|_cS={vVqR)&OxaL%j3Chhr7 z;%W2xq2D5sibW0O-KC@!6qPq!C!_erdkz(XJ<7c2~P(GT3(%}so z54$a|?*|ji12#?Cfw}+>;k2mQzP+~<&bImDH3Ny>D$nhB?ATS?)oHEnB5y4)S~!B2$?o^ zJqboWLpG*jV1>imPKgA6>qqS@X|;wEsh#I2Wxv+#Rh#B4DWjSI!Du0G|M1$A8MPi= z3LponU%Wr)eyoMdo8Eanrd^Tk*g{2AX_~( z^2R*!JR8X@gfvX&^`0eIuD(ZGUQ5JxCe852rdm0Uy55js=W z{jQd2KKHs%=nS#$E{dVyJK}tVeZ808kqf>5 z-s3Ae+Y=|mSt%Pg+Y~GY#uOz#h=rqv4kdbbq@Td+S20LA3<%7#rsizbil5DXA|mAa z+e*?=?!9oSiv-(>a+Kj^i|Nn}(;d4?Y!wYo(?8t@Wm=;{h+PqThWI9N_2{Zc!V6Fq zSjVv@xf3M}&pJDGbky4dkfV++uiWKxjg2fr*kRrJD^rjT2>t9*3}?2!Xb^bA)M}T^ z=IlA$U6iz&U*kmE*hv&M5#`IUeOb!+$#~MkZ`9&+hmO&>OV5D1`HuyOtj*Hoz6tvI z<5EbU0{KhJZ)}dO86cHpKsqc+YZMXRQwVfe{{AmfK_?Ct6-37F1<{4diX}^rS(=^v zt>9?C>D%K&JRXImc^t5{dgX1=3&vVvJ?9=~y2NP_Qepd#1&Z?5r={Hb07Y&^tG+*xOA-gt72vqN}f* zvw-0oZz*NNr@VW5!H1KWL3vpnrOTrsH*X=@{?z;ABxnU&pC5R7OcccKn|G#NuJ(S% z4ekg&Roo9$MexSjHM_9!PtinJM}dGu6teGcQz@_kG#z) zFcE#M7DxFy(;DV^xnS++p3qKBb@rWk(P?*Cgo@c7ZRZx7d5M*^wNdP ziVR?#xvaOS&{CgS3MWrx@Q6|-FWQ$roFFmkCO7l>v3ec9Jq#XVO?g@xzb%I6F|-JD zzpGVAd<4*k9V`N-OQw4VsuR$DM#)`KLE*lmPuB>#fGU#r%3EW&1^WV&3UOUM3`1#s zozfXe2?;wR&>+;ZT0LDa^YNy~NA6-BCpp#RPb%$DEarp={cfl6Re2)Hg%u*`4aAxk zVUGe$4RL_& zE&-kn0r4{?PMq02GzeGy!rH1j+os`l9)y$Vk~?(^aead7H5%UrYey$o!>ut>K4?m$ z)M}U|BD4^P5lWte(faqS{4N}sWVdcS{>74g$3=rM8cvb!k)dnRA<2^B^^ImAy+7;?|roepH* zRw!nhq@HhTXjqUZwJ{(lrk%u|{#Ml=Hr7fNQz%*X6NZMM(uggcJnWzx?7jyNWcTN@ zN|`y`ltR;-zq(*A#CC-@I5l-xx_aHMiX_h{LW&&J+`(1%IEIT=ch>({U!Oyz_&o3X z%+W{p%!F0@ei_zeFqk=r&ajk=6_>=mOX@TYfewGl?ZdbBPx0e4+^6 z3epg2jiw|vM(4MrLtDxdEzB0K$3jEUp>p!DpxjMGx8+z7`3`cCG-!5dXJQosX;+ zt4l?7Sw~6f34=$>-KwY|f>g^`C=o+zpGQQ@Mpcvh^_y#zTNjrhleb`e3R}{@TfYgn z;{rR_OB2k%Rh+u6=^hB04NX{{?@qZyF%on+2@%}5EQ@U~J7zflam(P<>8%ew9HBOt zM_ei48Mm^iBk7z8JE{R*mq+&R)VQku*FqX(L6a`C|b=^$<)c5_+(lL zfjZ}X9Zb38kO6nY=u_5Qv9sRI()odidxhFRrzy|C(DBl@Z(tf7J9S!TFyQv|cb7di z5*IUd-ROutk*xJD2Rk0tBnZUD(Mxk+34el#w5${Qq}^SNAZ zbETR`2EEyWGHg}yHdor!)m2n;dv_nGDhPK7#;pB(`ZDXxa#tjcJ0sERDuD16m0>?! zJEkl%(WuM{SV9$5NL6KifCEj0@ATbGka&!9}z}5(035pj6$f14y^2}ba zZ%IaTsp*5c&ET0JtQ(2@(V~4%Y4ANu(V$0UdK!YOU>H)762T#5JoVL&C@n4%V#c2v z7DyOQW~|7(F`hL!A%{Zx1eR)(eX2msa653|G_+PS#|_ZJJWcnU?1KvwluJ8qg4sL|TAaa8}SvF{cKOVWcozGhI_LBpcOqB0m@%d6Q^1c4!n~`{*!dgRKvwig`-~ z2n7$`hjwV+Ua*|h>=rb~*W_LDLaB<;%J;zEa82_OnY$RWZ8_E}p3czGH~mFUzhc*C zUkZ_yY932|Gz9(32P{!wvGNS;AW01$yF4ko`H-x{Rugg+-s65r%rk)+VbW>^Ngbjw zJ3?u!d2)V<-xZv!KMb?c@@p;NOsBwo#*>kiY5A*fjg9J)t!)T)m4Hs1?>O@kN1ru9 z4PTtbrL&fV&i6LzD|xfNd@K~$TeolDHtKs^?WK$uD+0kcUvR=9wsWBsr_CkvLIBmx z5Lizz|C2i2CCDAi(l#lj2xkLDQIOo(qfF7rF+Sqd3u~s0t<`5M>dn(O;Xj|@r2ZKwkhQ;+NQ<9jxxoC{!^5F?bwv)YFo1pzx zWL$z}`X!Jc{I;E3Q@l6|$}Vl^FwQX>wV` zTH#Z*#nOFxn`J{!1Jn_5)Cy;GZ8txSP@DZ=LOa8x!EZ)#tOPJD0yh5cz@bAn?6u}+ z<*n<~9_aDsQE*jE+JGtri-~A>9~&~WEV=ywky^pTPE3uOuLV@!K(?nJcxIC<_(qJ{#zddx=siY}mF!VyNl8}}7hug%#_4zE9e z7t*I>uV{KTls<|ZL^VLHifwSQw1muY-1jchDY5(cl>Tat;iCVdbWr3aD32=;>=sCm zxQkbV1Qi0m;sF?t#n==Ue*({oAY{Odok?)IYD<)c^y(ENNb^WzTnu?0N*}iwwHd&l}Flby}Q!L4%xxj%r081G_v4WqC9P1et}DiDeE=8>z~{3LyZ-WEz8_7GgbPrq;QBe^wrk z@iS6Sy&+nC;@4#M9G%#%kqH0AU;9fo0N3m>*p`vpv z7Tfs9Yl+jn{^hmUna}SqbeEyxhE%e2X&ISw_pFp3b7~0*m^_dFwMODiO;U-$XLwsV_d=Zc!LH7(iI z3MWEE2Mf-l@PlXNC=weH74%M+oV&tO#8}(gu!s6&`WDI~O9W7B1>%;>?&6gD`F!YR zdt?1?XYZF@8?`CgCs|kYm3VaVKZ@e!NulLl0G;;XOw=NDu~zXH=MNtZsH_1)p$Hko zwXfaG@uKn2>WDjMJe>I~9(KLBb1jeapDLeRrp0p!O$&*3zA@LyDh&Z4^YJgobRDOW zk_y88xt*Uf$NbzVVJeRl0M1tM(u}DV4(~54*rSP0kdD+7l{4?Irem1{A$|oq5V81- z;r&IS3{q;0=`}5?9TQLdCB(py4JfYq&4?@=v_49ZF;geV8Zdq9;-a8ff`fje7lIn) zrL}$@Fa8h=TIIyBiK4ygot>)g#r6_=@`ax)zVP;t%=Y`U{2%QA5*PF} zPNZ*LtajZVG*!_`rr7Xs_HkstrxF#jS0ox^*pdiX4k0 zDC}?o^67=zxLDdD3QJIRqw-uHO-lvr&N0=KmNAhli?K;~f-Nl`4@5Ywja|xy({*jl zKu^FLg1jZPUz@_a08|M`cj8i&7Sz>!Nc*{4YzjO_V*TI+$`77A8DH!< zaiMD>jTL2n&p%9ExO>>Z9zBF=s<7LN`sZhz+GV$C(-kexK@oI>sI9xXb`cXOL?>1w z5zxAC(>J`jP_V?pWyFHYF&pWzv|95>M1VZS$BE`{^lwpkIsV3S26>|>d+;=_z5hOI zsJhUZpc6fsG*Dz`reAI#M7P0Z$AGh!aV#`5=4t!&|8+6Z%NGx&Z?;q;wE(=m~9@lVQM6T8AEnZlWA$?dOSQlav`TApw^n< z+|=&)RLp1|6lc25CyRI~N@@(wEfnszHReg=VMQ}Cx)57WBB&J(3!tvC=EyUCX#)#e zLcfm8&X2P%@A6zc4JxwWtkSm_O$4&-{T5gb`6)ZL-n)0tilK7FS8xB&qK214&+XMU zk7!>|v3eBnuC;R1^%=|hjy0R~%<*rETm84%OgzCQ3T*da?zV#fmgM!J$h=s+zVB`E z{$O9ia?#pBpfO;|rj@5A4ZmWFzGLdQw*&8`r(b+p`e-pvJ%+P}wCUPpywa^^)5M0) zpR;n6W&+z0ye~q|;mN6vZ&ej(u{h=RjidOj$hX2S=s&a@)f-Gq5%zM&%2R5?bEhQ zr3bL>&H-o$-MrC5V$*lHyxWATastz(_}fK%eq+u0wQdxbfnKk*e3uzI@oLlH&e6_f z`4Pmr+G$9S#EXP`g79p-f8deGTC&(#?$u5;}boKnO{_IW`g-Gp(B#O zLrbtec+%&|?$5s5BD{_mK8ocj1d~9X)?TnlWEQ^vgG4Du@WSb|*+r8gk3WnaVg9Mk_m8>f`G!VPhjSopuLFJsxRZ+ z?Qv9VL4EWDXTH-PgTilqNPalHT$S8coC!#<2}p+!6pwnA7JkcM^{`Kvy3afc(QGQC zK2snXBqEJ)3^#C1GT&Wm`iSp|&!*FJIuenJ7Ppe-rP|{G#eM@?9C^&;`mK0Sv1+|0 zYnG|9teb?ya|71RzNGX5SxtIk?U#7AstwaE)V6&*PoPP$S1wGZnz|q8( zuYcO)yq|X!`pjl&yHjZEua0~54C+86+e3IG;*BsK$bn zf@xqeKHT}j-%4pHSrfiLw5*gT))2miT>#V5i1ZBbSrbxFHeKE;yi}q=Y#5mQdD7Yk z(t~L|CLXJ|c8KT&`Z*63@Eqbnf*4h=z#s~N5w1U5OOl;%8D?)jzcRd}J*dI`SoXAn zmD!ZM30t>=k@asF1r5g&fde5AAjtoBRerxtb@z&>ws5L!NbX`w-h3)tyih>nsWPx6xP+4BNO zy2xO}y8vbhpt->Ks+}>${n((XhO`mqLu9ib*dQvJ0|Xxr?9Lc!HyoV^hz0_nJ!3ZL zxH-Yo0`7AO?KHjlT>O||{%KpbvdJ?&zwxW0wh0gq<6&HPCkaDiEg2sSf_YQ z1$9_m(S}Gx@Wa>g%KP4PFbs5cn;_&P1k}QootSD&QCB!h@u;{@?aEz*3SjBa2mky| zI1$c*MXpFoMH1~inLW6|n%SGzgLkjbZFx;SBmo#id<;pF*p)~a?art(;MkIMET;H| zOwF;YyDH=2Bbm2QwzEMmrM0WB*dE> zGk(sX{zp_~EFL${j}N}FuAQXB6Hz@v_AdsrA(-H1sBMYIK@MF-J0Tx}jTSA^C*CTS zW*e)F7Sf%3OfFS}gq`0ui>nqf9KEc$j`k-M(GI~x2mu1}ndf?zt=NP4d8a9IFC5Syo*AFMT>UK^ zA=pftt1)>e?8@hMnbn-spwq4h?Gbkb*KD;voYmD^2*3oz%0~3(I{IuJ5w|Snl1a<3 z_maI8Lo;M{=nAmBTA^6bXNC9niunLb;9GNZB*9H2R3eM*)g!cX*-EsQh_13~j9fFc zLtg;dH$Yv46Uf9x-?@8NeCKX;;6uy=Hp_(F`k97zofAliMbkQgHl7CaFAA`x0?d7r zd2*ikTO0`v>3xZNXtZJI{2l_!5RX<)^iD#&%ERFbYPKL^qg0@0rYrc2oUQE7iYEzf zjhM%c?OH;FdP+|00r591p#6XbG>`NKG=SqhFtJPpX)U?B_6S9Ql+=WgFtoj3b#7o& zeZn+ya`iImBR_up7%Ho;t;VPgVLUjg;${UGA3D1*9wHQ>+)aHobkwYo9e?{RoW-(6 zbV#n6^i{CI#Y3V$NM2=KXUx{QX)ZqMTkV7;0K0rQ!$+H&FBDkp`g1N?Cmg@LGbJEG zsBDgQw{3DGrr!ltGQ@#+m2?y4gp4Jg9PJ!iEKy)blLT~5@;!zcynmVwlKE`W=c`xdi@ZO~y{pDA=*Z8xxKrJ~^LbAn-Z{sS2Bz z=Pfbs;SBpPFU0pHyL2(&(ge#!?GIWOl!F1{H@>u)MZHU;hIvX?1F141C2GvUrKa+1 z6rTE5_1ahvg)x(mz|5py)6_l5h9VYHKcQ(Y@9m2}oRM90Zkl%KA#W!t1;)vaR=ujz7w2L~RO^El#qx3dcv7pf7%l(u<;!;mR?cEo`!#feF%?0T zRJyV30(*-a8g_GBdzsrk^_Pz2DdLwGV^(-#$$o6nW+U$&stF*vpl%^IK+bLE~m)c~{WjseWp)zr% z`sUZ28yX~D^c125@1I~@y2`bCG!==|rpbJlva-_PcC#cEoV5}uUMh%M9lQhy39o)c z1#AVXSBaAln<8Y%bv(LpJ%h;Fkko#~k1^Cv01>oZ^ZUPR$G(W6t*jHg_p3UarOYK% zZAz09+Ez2R(&j%8{y$Gg?$GM>|M|zZj-}Jv4gIbi8#PmWZ%wEFdTJjfCCem+eqR+V zC4RK1L3OaieMzk8V!SR*hkM5Oz#7ZzA#r&W@L?O-tCPe~sW({xMX;&ZwtQQGCvZHA zBAmac#5O$+c+&NDdp3?l?~Dh0{PoG(C$@|C=5^m$G)lKr&itv}+yaa#s~8Vevr#NmGqMXOce~F#|MQg@{OkDq z*LU~${^vvf^WBvrJOag2{_odP_y2l;|NOv&=F|WF9lriQ^X;w%Qyo@NR8*{9B3jGl z|9jET)5{(I{d!g*Ei*wx>pbBP&!_IU`u7L(Wdi^2dK&-U*#FyF?fvPMqV9F%s`by` zThFZWHrLnhvqXCNthojgn>scrFUXJA-LWyD8wI=@l?m8Z_ zimy5yPI{4^>V98&<3xlq*D78++kah$8FfWbMYRui#eP3~B=yU-CpRwN_=9lZ|M5CA z!d)X+|Fl{cWcy@gFO<9^`BoS|UbQ)y`FmI5;*^?`jnsX7e9ni5-><4QTe|c%f`(PK zJ$o3>W1O29PSdyrvce08TU(w-@D5wX-V`TWv5Rst+J^@w2~h6t2sqs zB|;K&senw21~dOlH2gL?37PSKUX80T0%%sNG$&KIs6tLAr)Y$1&sA16zy10R>uaJ_ zFEra!(U(|&t22D>fST^fXu+a78NEQrcK&)986!J(disvZuQ%VwDcX%9=+V@XJ-ihn zB6@MelRwi9$6!pUD$N?SF9C!TNinfs#d zWS{Si)eV2EFsA7isdYaTHG37c{P^*ry4uw3k^AAUUV;Bi?$CW*trsrjKv{po0Xynn z%|+ET;;aqM{rG%2os2<)2M2vUQ)-`i{9M`5_ssDjd<<5$C|wuL3{}ylw)vd{>G7s5 zTk>QV|M#WUH|}yt?TvE)OWtzR|HNKofLm&FeTRe^y%=; zu^_U4m6VuOFLieQ(raeJ6P17WXpbJoOGbk(tm*rQL!TGYFg{p z9m825lB{`0blnf@&T+9pYh1jP!(BaP)cK>$Tec_>iuuMtapO&V9@JCB3E@r^GpI#& zg(V^8RG}EpS{XkGoZmab$OGj9F`<=-mtx#=e)+5Nh0@&A!!!rb`#dw^p^9=#zHOC77uZ~g&}++9>zHvgd1CPrVe z?hF^-k2tFFD2%G*i-1|b-lZE#u$!1lK>|}jGRR?N`daU^fq^7ad&zKa7Z#dqh^4@J zJ}K#I&J*lh!8Uv26=gw)=H}owmTC2>QofFenuXhSCVJx7?zS4h$*Id1Nbpl+J*KgiTPj+UE zMo4e<6^tMx`p^A-i&2`)++v8I6k|7DynTBR<6XTg@RG69#8dy-x%1DUIlDY*Q#($i z7;`GlT2PSJj~_eluTf^$JPy8#Fl?yK0&xiSO_85&j2?P2Q;bKW6=7O(A~k1*iTAH9 zlATx&4l4v_IVcQD{8CqLpFb#HC6mF}80zCat__n_{-X~whYD7bS$6Ky^$yTlKox{AUk?7*(elBXX_U zT^7t!nWCn4hw;=DAm;*}=I~xLv)pB@Y+Wtqtza@(Mh2Z~o9KqroHuXII6s@|hnU#Q zb{CT_g=8-aIa_#uOf^hn^c^98Lf+FXm+6cT?0JygHIXIHRK5QGdB*U`x zbZ-7Mzp>e^&HPCKXthnrrPjc}&18#869s88t&-Z`MdgctR0ArfqEcGo;GCx72}|w; zrS{8?rKj&bPWM5A7ExBD3x|j)wUYtd%N&4t z6raeS%-qDdTmpr+IVx-bHw}YP?7T0+E?QhuyTO;ejujsLi@_!0-AJ432`EngDE~(N zMc>5a%ijqk*5|KW`I8~0TA2)dYY#@InDMRb@L_AkhHo`eJ$h6ZupO1Qepnu6sgj^@ zh>1%^R#u~+rg(8-vXdqm)ww;spy?4<;FQ*}^}{>egJJp#LS>qYr;a5G$;>R6INK!M zsUp4Gn59=FunSMY26MA^*`l@=oqOO}O_QOBvgFn{*L&@A+Oy-o^Nk!7_ zPPj@x|FVG&u7TWA+cDE;;WRNFxk}=n2DZ+x{_Y|)p2Z+3rVHu$AGw&Ne`C1_{d_A3 zkTw81rsiF^#F&=_@4YvrZ%?EuH&o1frNF_M89^Hp!)+gw{N!fu3K3omkdL2V?X>Ry zhFv}kx}8HeZ*J^fLZ*bk<>N9yF<-uae@s-UtIQ7$k!jg zN2q9tm8)4_#$<@;w?>`;3*xR}zC^1*{BmYUW#;C0`Ll<2>j*iymF!x@jN;z=SiF#Z z``SP9SCN9Rx@Y;gzNt^LvbrI(qlp^qi@G{59`RkKC-rpyt+=?D93`C^FokPF`M&TSED|GmmoHwr#$PZuBs=uziI5yIEvbnuoo)7RUMarQiN zWIwwP>3?M8HzBtnW_VH#ooM#Q<{N_$P*h>$=@pn5JfIAPTI_+z`%%H^iq4N2K_CJ= z&+Y>zb?(!rk5;DJ;rmaQGThFn1FeXW@4wy_*-xY`TK4sH9$6(63r-L!-z%xgDysJu z6Fw2fM-wMqK4|AY*e{l(eMwZ&(!VjeNsK^GP^yvCXotx;$R|}mVY0T={kCx}(;V^i zQhn90vw5+rhr-srdA|qejC8St7&r_lntnq}N&+-p1#oHhW+xGX+{lrUgN7)t=`7dN z-DFG8wkW9j2s1Ko-MYp6*3VTXjNhIB;>FN~B^&2{bXz<3o zGIv`%BSOppmBneUGtarUx(FYxGc!XN`aTfA?gHQqJDGW)uDUEmU9`@ih_Cy;i~E~5 zZ^Y~#Dw-ZRna)|Zt5zA<44gb)u4mHCn_?CNWQV^oRt~>r{YaJyRpiz50#h^~W*$sX zS^wg%4oo=c&qwi#zZ2P=cYureWhkSB(hxrStLbE3UY?lRouI@j{UIf__|PgIq6sXK8s zFk|>9w$S=Hgc?|pitQ@Y5_6umJX@+~ke{O3eh)>0kk_*{*GP_W9w{3+TW7%n z6`5mBw`>n58T}k}iCBs9y=Q1<1)x4)b57L+3SU%3^&(n0^xc zl^}T~i00Gl_Cu)oM(6npmkG8Y2FNEk_|J_H3~0YzZNmxYKn>5r}im)@IVqLZM5v@AetDQJt+%DOwT=4YTs=%?wWV) z>5-M0IeYu|IO$D&W7@EBVU*H0r3t4eQ9r2abq4+A43Iq78v27BAQ8YS7i0RaKo z%iA(|&2&e!R)do$lyzcot417r1PvIQD(RHiEU|(b!b}wW2192KHChMK<>0T7SNjKOHs2yPy}C$|Cu`-k?zTMvZbA`S>Yhhx5uwOZ^jxzu0~k#96t+S3tr9~8X~Ntd z+n>n!J;ao(JU*?JnMX&!OnfW+Q#Z}iG%6KvNjPc2{o4@z-731ZoYzAhM281EGsb9_4_*6!bYYZDD|HFNP`~9iXD4@P+2e zKIi<{H)=P@6#+8F{+Jtby?pmTpceroHJ-XlIOrt`KZDkw~uUlwJT7m0!2fPwdI82*L~LGHQI>me7nClbxh zv9pWSk4`inEMjAV_~4~Edduz7+pNBUQ>b1Kuf!l5f%pd}p*%Ef^ymdqH`*dD#yE>f zEMhX2N09PbNg+*beDe1mgbZRd+SJFZD_z}Z$QM6-dY%R`v#e5@v4W>KUFuSU*ph@0 zD6Ns1(^oyp=7goa{rKgB7HS%qw=HTX0o%|g7hW)({HoJ~gUq6Iq&}&H`3fF>FVt|9 zpksGQGs>S!be==kolX(Nyp`+MU;32l<0G9AuzH%uqfMVaeVTxLWK7*eq^54Fmw(}2N44Ykq!$zz*N!*Z;o{OIW2C&ijfyH`v7HQo z>nLT4A$S_~h~iFzBz-CV<>48(8gmzjQLACKa7p0rADUel?FpqqR=Nlk?}WO!KEx`0 zJsv)MIERyNvc={q^0kaz-{9mmEZ_pUA_JB*-fRO7GQfoFk6EO-=8v|(APc}($bM?f zHTQb>`0={FoiX7CSFc@Dg6zlN1%CyPU?|?x+otV^>7pF<78m`l*D?CtZks9anRBcA zXFfj?NY9DD7WTg(X$V$&i)pv^`T>h;AF<+$ykBN8S1RX;g?KK>*Fvw#DjE~sCin-4 zEqn%3SkiJgL!?tEn(qN>`Rh~M)vI@a0!N##UsCDH+q{Glq!E{VtBCQ9*@}I4Y3u0R zfx#A^`s8?>z{5d0qc@+eJ4SxG5;WUDr3oxVXe6qY+C*L$#Bc0ctRTG&Syi_c&7_3g z*kdZ2Ok8Uz?uXd`pjp=$6txuQ=Jv$mM{Tw+GS$#%pW51N6Zk=8O-3hU;SfH()(&23 zp1c=Kk!Z5ZJvB8o>!VYdocQ3u10&BZ|K#nDL@9gIytU>tio!>mxtL*%3mX7zjeke( za5^gN!=!EO?Y+pc?v$04@mxU_!yS*^d;Ama!3olmZrJV}`6Z4veQXa5UyYwv%i8+= zi!DXe>akuKKxMyu)YO-a9Ne~k$9hXkuOzNgE4TM(yURIC({@iMu^%s8%e+ns{5cy2 zuR$BapxQsVEtSAiM-IYt-ZU`}n*SUNmfxo$^5n^r!}Cj*1`*(ny!A5jp;QwgMM!9< zrh>`qr&@Xeal%oE*A}^w0#jMC&HGqt~N-S!;j zmz(7?Mgb_bK}Ic*y82jn9X{MxL9fbPembNjkH*Z!m+(+)oHb|efB|y5wA8%zn-~T+ zsJ6N~;`0xF_EtAo|HU!YvOhn(PzN`@1uwGGwr-r3$pGXJTAN$B{f|ml` zQ(#Co+|pwXW8->6$HWY2*=T8bcf90JSMr4=&ywCaBu|qk9DfOJuMx)~%zGGggr#6$ zSZwZR0Y>gT^l{RhTk=yJtN!GaJg6twOS!DAxyJM87CT_nYI7Q5axEXu{Q9DUu$S8^ zj;T(1yz1S)w!^Q!@WS=U!f>+^{2??EVR9aB;zoLoE86tM$>EJGV<+o>xDm_J*UrFU96(dGnz- zTgi+x3L6^okE&0)xpk!e;ttnlI8%&?CQqFjlhFR}00)L_uBPH-_Jl%sc}RS`kpheZ zZ~ht2T((C+Gt+XN+%@+>!jjxNMD_0-?Vak5NeGv?eZ|Sc_XpM6uav!DtuSrcAkILH zV=18-u-Jakf@#yId;9xGI-g@u@S64O?-mpk=tpqT{qtd%73>aZaR7SgN5?y1T~=9{ zg;nt8xpO;;>=cI97`h5AL}dU+OnaogeGpd`45!9Nojv3dn&WSBr{;LUatgXJh;g$K&dm>~oBu;0(GFNk=; z17aRt_w_Z8Y&(1QY&O*Ln7069)C&{!qm2}MTW`_;zzRgHx(`JHjW_@=rlsxW!auO? zf1{GH;Y32#nW`;ffCMORbis2|Q&YvJz@S1Lm{19b`sRCv;f4VJqs@n)rwA)mbe&391H za0h#N>9izz>0XL+5?{X`X$M*iis=8b<=5UkIJCJqij!~VI!p+#W(ZrjddSbC0tV~_ zvi>KjN-xi#Nmf0j)UxiA3W$hNI?8sU=hdsq90!Z{qt1AdeDr5E;>{|oh@+kY^tM&m zmZt8UDsi5ACXtIMc${~RC4>p4RKpx@wKrE?8rI2;81WJiLCj){wMdDaYymuqev2D3 zd!Lt0Dckf6JDg%@wNgk45?BL7cf6wcR!7I37iBg>2o9IviE zko_ZebJ><%!bBq$tSB9orS%EaS;^s*=1aXAG>KP(rTfONU%S@IxavndLA;OwuKfNh z&TAC6#w2DeW8&fUT+&xY1sqrnOzw$Vxkl7Ce(h1IP}z<@5~~bgqZqK)+luHe;Gu8GgCKv^2)w4AXH9B6RJCz3M0dTtWASNkuD3Xxmgi z`l13ipCREg`Cdx_NtB{!wj5#}FvUTZ_&!`~bz6*df~j+q9*wG>V&*75*H^7)^9D~n z=JH$(m!3+L^OAP0-fbIzUJuLLM;G2y_Y|ykO=;$0)LfR3qU6t2_4^Ib@!yCwDc;{w3&buw$tMXOc#Ws+;6*I z)X8oE2s_*sxhy$DW*uWAA`&rl`@RhcuFE9(9WGZPgprQ zIWME$Iy40<=R5F;P(a8X_u@i8TKZSnGgVrfIksd{m##Z8KEo68_Iau%))hz={Wbhk zn;!``8;$J-)N)%rcI;RsII`?`*cU??jL7i^LZ8DZ>%tYA6G3limlJ5(-Sl)VzeaI% zX{uRvz!G3olw3Qkn^U4gRg`Sq4Vsjt5_31in~n1uK5=40&dg*nt5EP}{$M z|ND<0AB>8MD%(uWS~>LAt!crQcUncUnoB9jq5s@R#!YqQZHZe<&e5q;(#q)61B9KF zSE*`%*A@o{@S+K=wELgOC~GbI$FceVZSEKE^^~~*DLec45DES9NBX6p-9qQk*hOGk zTZ}+6a5Z{UcGJZn$_*>ELdWK;8zdnnq63BAcwn>L#pNXfrtk33qOOnM9~K(Al1?n& z*acyR?50jV-9vNxEh?glH&zrrj5~1t(>b%F&6_@bx+5A}%POJ2#=)S&NoEU<@|*c= z=P?Li&YL^eZPQMw2VGBhGggJzv=s5cuv+ZrQu?wVd&>szsNHMYf7Z+jf_{W9Ev=S7ss|guCeVo+?S@_d>~E#Hwg^S*Y*)J-MSK>DAuX?{c(WP+AO#I=$Ma9IhXNP zsw*cOmFfn~7_59II`^}Hb_n|yy9A}AOc#V_hgPy1h$6-#<|Z8dc1LuFz>?{A;>3xg zi6h&rsQ1(fh3m%;95_&5Crmf%BPD~vy9=(*b3w7?Zi?au@{jkei`9^ z_Atm@^U@9EMHEmE=?~IAM5HS@TXW)QQLb;dv5yF z-jW>=HtDo}{d(QAQw%c7_xyTPKiqeJp{MvBKQr>=hr`KsP)ePRL}dCm7%&r{)IgA2 zNfX|kZ)7;eeF2-2w$`27bnEiv%k8e^T*etd$h)0y)wxG$rHNMK)KNSa#`qXbwP1rm zB@)G70@DGNCm+5!J!6;J>&;We=;xO^gHBOFr2MTg_&YPS1kQGB%jxFn?I`Pt62I(3 z>Xu=RNt=HE{dX~ecEca)>gr<3Ef&nWA~!G3J20@@k|j$5+wA!%jIvel6@ioIWNs3J zl@X5#v*eJoclKsagBjk_V`6^aOik4j0P0`XlWYSxHCdu;UltL7ubm4 z(GmLNFrHShu-8IgMwrMs8qP=gB6%GI@pG*Y?3K z(7N4eKWVHhU!l-yg}{N<_U+rlYJt3Nqi*2#7A#d(HgA=%C;d-6eDjq)Q^=2s#{?EA zPQ|$eub&-~I8Kp|Uh6JThym6}()?<6??|6ZhFOVxhHu92UJUB1tL=Zm67X}Fh-y}= z*viecr5ACE%O^P@T&^^kvqk8Wfy<6J<6J$;qhTe8c~XQ1{u2h`)$$m;GS}KVQa`$C z{dm%7nujmIMi62Y+0*HiA4ie=5XoSddu>=AS)LEk4ac% zV-sb3PE#{mfD_Q&KXp+{mP7LGvLX-jL8q6tnZbNF~B>YI7dVXE~B?q6VeX1UehX zEa5fl)>$7-+4X(0n1w6S!zdN11_owkk&f?3Bnqs@_23l19@B(?P*zEeE++KiLJsda zpI?Wk-!5~R5K{O8D0O>H>hX?zROAs@nHO%|dR|pqbMxjkVqpOV;FPfx?s3N6W(UlFXE9raDLWEE5wH{Ti1RaUk|ZO)U_aQSh4GmA|i^V_2IMTq) zk&1;Y|Ljri7bfW37hK^bD7I)t@vMOaoz69ZWe!LyK7H!cK_IPsrJ|xjOtlZ5va5%<1GG%4Zb76tvT37*#kt3| z|2*sAlpq8@uP`RyaV>9Ulr%+t;>5*mpYiAzefK8yc*gJ*X{md!4zne7q=&Tw+%C14 zRkpU>xbM)57bWjizJ7h4mOE=pD)mMmDg_F0>Qu6C7Hk<&r5I4r`0X59T4aW6CesCZ z6t;^|vK@pF+99#{x|KI>I#%za0AY3X=H@q&w1P)8sIhMk7Q(QId1ZP$q72|I4oRcM zmxiq&X}#knNONBq&Ee~bEZ7h*{KBbTG!b>~tzLW0az7^`C};?pl%Wl3ps{!2)~;Qv zcA1so)s_OHvQ@pjUk3@e&v)XmpKU)}V;66$YZ&;Y30~U~jKxEoUVqYQn1pT>4WB-3 zqE9jSY2>5|iSx!Crg&F|ViFZwwxIg13PaP%_gMw?^;MKI2t*PbaAU_a*U*KK<1$uO zwz+Sc1)D{hI0HO3%o4p-EnmsNcIV06+CZ8Cf5EN*F}z2Rec)By^Ba~gjz+5f%98_# zvr43k9>W^q%6((6Tjh146JfLOC&6?S*g4gn$>W9(??6Ulwsp9xnW^cSk0ZYUg-6Ak zXlWhXM4spr6)$2&VYOhNO$JCBDZQ-Vi3LaWPMq=7dT&C>=i7Gdm?P^wD_Fm?Sk``1 zhPUA7{wv{)@d)%9H1$GUc+(JJMODU1luVbmYHnW)3K9o^6=q3(o8=Bf`nVJW;^XklDB zf8nA50ofYQ@9SRvw~`F@n-zxcCM;td;QgB1KYf|JGHrp zA8%hES8EU!a{uSB=JPR6DhhduUBk#S?!;Y~7ghaC-HRpVX`seHAeNxJqK{SoImJT|Q%8$=i~U|_m@Zou!jV3B@NZaY}M2vkr#*Q9+7p7jfvpMAi zyk#zxgQJ(<_>_|X%~a6mw@p0m=y&#P^rv`1iHC?-Ydh9*m-iTq>(}t~$ka{cw_GMt z1JmD#@PQc6*mf?&x{gD@)@QPu!NthA)1W|I+w^WvyAJSN0^5FLjU6LT{~){#eoC@7 zX!`d2+x=UZd@}$#>D=wx)o6~PIp6@8%KxJ7&BM8F+qPd_bv4hWsL-g=oKz?knlzh= zqC(14NrXaOP38tvDveYKQ7F-j22o0gD3Z(xDUtO#UC;fjcfH&5{=3$;e%p3;b%&qd z_xnB1<2;Ui-;ezO3mY_{lii5--t1@HU&>}S{eAI zVWUTvdQ?D&Xb|vL?U=mqQpRG;Pv)|D?Tt&3pUvF;U_=33T0S@+;S-pOARZS@O&IO9x;(>qMwZ67#uFP z4i5Ut4tSyje2B&b<_Oe4nbJT4(Z$=OI^;DMXsl3)=r>@1eE)15a?rAu-8w`PjSKL0 z^z#=TCrz5P6f@cdr*{zg8pN)(~PEM2HpR-9%& z^duMt%U6D1gez1cCWHZ=r@1%Y96wgwu_!G+i;s2v3Ge@?HrDSEoPL#IsXBW!3Ynz)&L#NoH}BP7il>>4wth;V;u zibjo9oMo2QmN(NUPF!Bob~De=7$MZVq*q;&GXfPmQ5fxo%rLH->G*A^v0ICM$ihps z`}W0I->o6x+`VOiQ$ybGD#PFH0buLZ9nGj54rva5T_K{-!Mk@=DI}Fr&dqU?CFywU z?BnU#^^-{{5J5$=Dx&s*;J#kln`!MocqDKn9S_`xnokGO0@mN$vYLf?>GW_##eTS; z^!BW3%aWRr(kk5Tu5c6m;+nNYLFVY84k1Z zdTgwcV5dV&c|Mxt%Fmk_KOb)8sQ{9#b!~(_!fbJa>)RTI7Rd8x0Qb;GjFoX?6!NF) z>c~d>q@<*SMY=8$^bAg&A>?zc`GZ~6y$~ld{usUY(w{Y&_FcPn{f3XyM}|-TBY>iU z_-*U0u^g&76Q!FA=&>Cyi4ap!ZclX2K7sCLWpo{<^G*eaKZINV0k6m5ueCa&)a~cD zVYr}Bz0T=es0cl@zvVZ#_szE-Ki<8#Ne;5gKGv<(^nQ#rGno$&1<4O@Iyk zccbPI!pYlfBd4O`0J)ftYVQUEFGD6NU?Am*VN-iDPm)d(j_Cbu6uVp9%NZLv0*}U4 zWQC9JKRV`hXd3(FplMQK;sMi=9m=p6hrtCc^}gman_Se>Gtaf{`fDmVaW-eRUynmE zv(sx&^_|iKcZWMWI&J_;a-MUM?;)*ob}r(X_VvYjMp9Iuyh7}$5iu?GhWsBpVP{-? zd}_`8R$p)hJ&U07$XmDii}PotH8*Y5eOH%=FDtaP)Pb3ADQu8a**S)Lr&9*3t&FAg zg_!LH+1UemSc&?-Ja*XIOM1R(0O{#7yG^dLfcZ@G#ceNfqy!biX^zGsFg;KNBQnuA zam&#me8~EjL=a%_Lq0w}EB8=nAL=gs59m`dVe?T1B;`gp8~T=#!Pk(8hfUiG{G4-h zbFXeeK0`;1`dgSdyg$XB9&GJ$%y-0t!i6o*%%(5o_Q;)~4*9WF@_wss- z#^E(07NH*AG{brm)id2UJv+}?(WbwnHl$T}rbFea?TKFtxLtjwZ_Rh-5ts@lHR!UD zkx`!k1I&;5$&nSLy{3aDR^{aeIIgc*;RK7@7a9x!(nL$($I{ZnoDO4WAG`xL{1Z;u zUL`c@iZX3I0t*y%I>qJzlj5SHZ+!bB;93s*zUf!-BN39~VFCIR?bgz@LSCR-pvr`V zkjAq8v9Yl$^GUf6-fv!@sTpP<^65+OheOlAp6L9(P*YZBaihWS`0)d#yQ9Dl8VF^gZNa!eDDT;#~0x42Rb5 zI&gUV2Lfaj!WGcm6?=YUWMrh$Rn=Qs+5;5&6XArn%ThQACTD`i?;BRG#!yY`sNe}%nsNVL!gbHRwlj-XJJX{;$+tsHgua3>!B)rQ4A zg?~{WV690*&U1w*Pi0Ah7m7TNBN~a|=PzG|j~?BLZU~x+9?`wg8#B=#MqX3#vi?gj zth;yr)-t`{`bU+`cO47demCsV*pBPJT3h=cuE#c%$~U#EG|4)Jsy;*Q_tL7+!xDnkB4+hM zaIdE}Pc&|d%Z7auttQPh|2%4Z7rwd|8IdxeB!QxF>Ln$O*aYBnOJko#x_qNfyFq4^;+wGM5Ks5r)$skA!iC0E$oe&n^0~$Ak&?wpv zBDw{kmQoJ%_ZMS?1f~$an}8|y``fo~6$Yxbe?TW(pML%JGn7;zL*%(zcy2UWq`+f| zE`j|KI>-yH-{0y2uV#8tF{?8=`-c=>M`4`^q&Huo^yg*i&l`=vIevWCAht9`fc{5- zZlRMk_=>XO!2P>FxAHEWt%}c|J%;2GNgFM0I^o@d+I?iLo3HvhK0s>6RhImxdp~#uNx~YmMA(OV zQ>}00mqK^Fevxl%Y)tonf`DTO?AMOnv~i!=32ak2Dz;HLq&;EBL^~zqfhqJt&OH}ZOxzCQr5tg;c0I5@ZAIM?_Y4;^97vZ3 zk1^HW%m$=8KRk1YcIc18gI!4cFztxSOnBa`v+nwl!X^D^mjp3}MwF|rt|l*22I1QuYc7jZEs-q7q}L{$HFyhqy%{jes*p4%T=VVWUS%`u#((MugfLokt|3%#TxCt zudO4t?Wh0y53?WC{>T0NY^gyZ*++18gKoo>{evw6r|l=6N%=Hx&Z3MqiOzrj;Q##> z^TloHZreWi-~R9y*w+8_3LGWilmEQ3)S*uol8wu?*wLuhrV#$m`|EnNLEJ2ve_zFS zo$3e%I#mD&C8oleh?W%QEtv}~Xn?J9^v;qB>RQlxoW72-x z?n=+31)8SKVOuUeG`{3fkYe0#*~sW*P4A>s&mQ9j^nd94_h9#WSJz(6kJtChQNAfz zQsX2xs6hWY ze0brz5vWYwj{RSjUyIlOzAv%kAP@yM*jZa&?}7!wf4`o~ZvKB|ftvltCP}RWx%fje zdx3V%a{MNNoE8D;+#NfnroDVA%u{pd?%Y|Hx!u8G0d}V=@vlXL!P)rVpZVXvT4yBm7Hd7@pUx--z#I7s9ww9FC4N zNts0JtIIXxM8|zl<FRiMKzf`GRMZXYsYlcs-smti+(W`jbM zB2ChOI9gr{%zgu^gO1~Y3|El^r|bnnDR7n*(r%7W5{JNft*n`)X*6P!hGwvwu#wm+)%)BLcPwCRp{rhcjh+$$*!(bqD z@XE)Fr2D)+Tn?Q6t-gMmg|O0orjlaI4*{KRuy}TS=%qyFt~rrT>`$4lXrA@}ot(JA z$7H`fJ@j-D1#53ILID(a=wQ}zkMdgzjVos%vI46|-HFcM`I_oR4dx>%zvEQ<@GCccZ)oEB~tuo$)LD=#1ZvxnC1P80t(q{L zJ@mgeuMjXHs~Y9!j?4A%W^J#UCA9i{dVKgbK$$Ln!Qt@R{ASZ{MakX98^mv@C4H${ab23l3fNA7ZYu%&g~06sfDd`w-Pgat-LaMzGbs0>V!7P- zZ-S#KaW&8XKUx6IZ6EHt$ssa0aP{iwZf%!*`8fUhWTy|oHGw`R+Y9Da(-2AAs)=^! zA4M9+l9*wyz(32yQkx30n2$W!LY53$9R`YG$s}`@xC~z36R4LiQ`|W=)rW@ah~<+R z&Mif+wt-sy-@0+pIF8Ea2f;*MU0F65e&vjcB=!A~OyS&s2!noAovVLIM7H8 z*nxlFNG!!S;F6+qUu?g(p<+9h>1k{CrRT#I5=-YQCHKF2lO3Tv-Us+|psjEFtzR$4 zfLRLpAk7?CmY2*%aQy&cdRY)s5b;`d4kLa~it+DP%=lGJUol3w436X2P#*B_2hdIy zo)D|jh0+WW$p##?|Ld)VF9+r;mPTLKG18Z&=u5c6kI@e9ObqBPCwGLNehd?eNF;Bg z0!K9*>p2IM_tw@AgM{a@MV&@99p{cQETe+^d z#q|o1ul?kgK0pt#ET}7={Ma6?JNga+#Qi3#C$HSAwM*-2KQaTln4%*PWp?yNl|NlMCa zx?rf}(IdZ6NGJGqBq9~m)1#p&*Iir)m5MCdFF#`wyf*+q{>qM8p9&Il1MjcIJ-@V_ zf3v_5uKC{-$!pmb25jP7;kR`gl>l!gtZx&#Jg4^&3D|b(CyGc>N8nQ41Lt?`yz{f1 z`XZ_v2nRgz1FWlfy-pQiF1`6@jZK(qJQSrJqIl7=4h%3!Z~iNInMcU2X5JOqU)g-o z){mq&-ci{9{F@5zqgAPh6`Yx8Ze}L*!J<98;JerYQ%7=L(rf5DUv*LaZ{2Gv+a3b- zjZgxLZLi)rRLZTnw$p&|Yb5OK?AmW$qM4%4N)nn8)4>$|1n2r%mEAwa^BLqn+{2O` zvzKmSEQHQ|M$tS#)P@0bCt)wdy7zy-73H(YN1OIq=iHbVz2%VNRq^zs&r{Vf^2P(z z?In7tPSG25?`Kgr;lQGj0)U)uviZ^x=9g#nk4jR;@?k7uvIb-zhgocwDGu?Po5n{F zj|}yQmQ4+{7RSY-%;O;%nsR4AafXf^yV&#$0PSG$vn?O{bUSuTa?DBzC~9>^-%SSh z55;t6xac7wx)eUh?Q<*|5WNR@m7@J1O*y4u9pG*eA`QzDQ6G>&LJWpVDp`4Nj}vX> zorPHqra+4Eo@yfaJPcU9Bz_Jy+wDY_)#O-(!Q{Y6#{QOTdB#{9u8caE9>_s?5LLiKWKm%) zkPu;w0xcJ>?qzMrv|D5ioVYTxm*}m=!m7-(NX6GI}!F@Wvhn0(jAQy(soWoXuc@O^*}U4T;w( z;{YZc>4itw?h)*R+DOjbqg!`ZyqnP8ye)FhIe~s{%p{?wCw7QxG&yeSk^;8b>))WR z{{TCN#?hue4-%>S**@Q1&M4eK_YC;RtDK2!Yj*B@QmOj)Zo~~jwJ=Vb<@CJ+avEU! z{kY{35LaSbE%pG~8LZR-jDO2lp{N&DGXhEj&^iQE)_K?lBu1?Bi$v0EjBR<>dJ_Tj zI;=gc?k`%lOz`DHFt4I5SlA7+)$f>}XQJHS82j6DL=tWhr}UXVeL9foTqdNgWav8c zoeJJs3^H1}v^(Dh2XN@*r1c^k;fHk?o{m82h3@DAj@OeF72hbh>gx6vITC3_5n%Wt zEvm0ny@rHQWM?Fi*F4>51RtWdM~YpO_CA}>7Sr{bxVh8;WT^dNj5G+N zJ?fVho;-GJA1Qnh=k;wGvxQ!q-iZ;|OAvoyY4?hiWn&}cTjDmN;$+DT5c_}uGE$>f zXgHFX`^wGnr2*g*IXw!1KFCaO7PK<$n_fAccXs%yE`$QId$~B9bW%8%pkwTd+xk3O z=P_sZ#iRZbm4F~Z@cOqYpEmZSpm;*wGpVmVXV&!X)3*hT)#*sVlLFo>nFpG^a#C_S z`|c+Qoe-3sMZv&U@XTY9b?>cPw|=Ev2{PD=w=O?>_}-HzPsZGH{qp6D5Kg7Lz(IW? zjtU1}&e}t;9pkD8a@CsP@xmsHhh=H>6L)*3lAI;+6e4kyBn|*j zwiNh;i^c;n$VRU?Q@t~7MbgB|WI$z+&nJ=;R;y!ClU1`h_ax)QoMifByX5_S999p{u<__W`KjNXMVM##T-R%!y6L&7^b95aAyHS>=2U$co(fR3krnrovAeha zoKLuYdjo1#CRY0B=M8>A3>1`%m9-t`8v{T3sCULKVvrD9(vqatxM}NFUm|QV#>|dU zjw(7PCc{Oa1v-PIT~pM3$e};NOCLh-t)I7fG-n)#Rx)d(_VcrTNl#EJPK62FC(6y) zIaTH5o_yKDx(<`U3rJKAlRNEYofbotCv(t+bh~>K%Ef~xx`CGG8W}4KjjG?j$mT&) zjutH&*_!VwGQpnUWhsC)l6e#Gl7vuKR9`Tc+&UWYDA@sR#SUS+9DmEBjR<5S zo6f4mM^1RWhSoiVhRq#0{~}L6SqyQBeo1s+c}cU2#J2Nq_?m}s+a?@jw%{^o%3={v z2#JE7H7ZSSp;BOwRaoTTXrV7s)Ch+)LAfv5HhSFb*)nK|lKJqqbgdd;{WWVq)HcLXjjuaduNzGwWVDqw_U9h-S3y1xpWlfEd{F4b^ z!gbTu*G5p`h~Tap#&+eF=}F#aNa;_cQu8O6%ChjXdinX z%1L9`*hfZYf-o`?*>&hWe_1ZBKg+eBsO(!<|3(Vc?DbpDuHu&HHk1CjlmBMb_UqgC z0JNgH&TW00FFdSU&Uqagt?O~`LkJGziY98mMR8M@|?n%dMmwF%;VLr3@1o>+M<1ll1k+l zQKEZ-#dT3sR19iePg*F%A~hwo;McE1mJCd)>c@lQ{F|R;r9y9Y^6kJW#JHhaC&q`K zT=z09O)Sd!CMK7{_V`#3?mRPu^FPVXK;i4{>|{j5A5Jd*(^wYJp}&a1OT-=T%sG>k zsCK{Q$C)wu*H0DvD0EY!tQyG$M(PPK0UhsFF^@)4m2}lxcDCWEGiNFqei_oxkF&23 z97KHYtN!7)6N}^#4eBa$V$V+?RrCK%NwB1RCPC)!1urbXt+bVyGOk;pR$X|1l2BE~ z7vJIIdGQIo$oE435O1TT-e8jndvi>FVIsx8=MZ?7k_zVk4YUPD_Fq;~Ha{RRn` zWF3d}x!&sTq)z3tmKza%u!g&5ZHL%v?==R?Ew2@K1nE+68wKobVy`xRTq=YUE<+7OJ{T}Fy2p7Y3cox zadw%=NkS3LS7(*~QnJojsc?@qu3wxEyrz9V3R7?(uVJ{j*Y883ZofM7jb_rvk0!` zq>|R(a;=`U{$Brrf`Uiz@Y^!(KzNJ6SmN@E5Fo6sEG1;CyT_L*0#+XV)BOb}&CNWUy*kI0)NympB!*9>l6De}JcAn$IUO=8@E{X|h9FgZ*& z4+wzH<*@2vG91qr3GKIS?*2A4=Ix>Wr~}AmHgG8_$q+>(QoTTt3eIfYMP|B%f!SfX zXYFE`xF8@kTLbG+%mI|bZ5aJSXqI=1mcW%s5J3D_Saz{BO`wxWG#Dd(1xY9$WsW3D zG@O#`XnlBgy%(H|Fu<(;A-%v!XmEnx_mES||GH!xrU*vG-Dd?-&Fa4V)?Nbr*+Ldd zG`*cc=631JL2S_c{6%YQ{ymqqQPEw%MJU{4ZkQx@Ez=t~4dK~5{%?$Vxf3c}ASfYN z3Nqcy=afc86dvzbgHNpZ3Bd1*Cmq#03!(#g1DHtGlBGJr!MB#_(qoV(1mbesSINpu3dy3kAJmtA9#@G8BF6wXDGDZ~8 zoKzLl&e%C{LBbahap5PbTW6n^6__NdBi`>LC%5gDD@la=+o|rzTx9f$QfK;gjF(XgXSL}3uN#}h;po`?{Do36XfO^Zy%8G<$Ub?Cf zu92;3MO?77W=E6bDT(ta)N1KeV4BJJZHq~V;f>b~fer-sL+%yTXs11V>a1K_DVrm9 zkX`NYIwVnEJE=JLL0*j_r$kxBQLXT?=9yP|4pUPzwjC9}lsewfNNx^OKi`&$X`LWZ z7gX*|5flG}uXVYVVXHrd5a*Lf%M8Zt?mTR=X-ap2(1HzsFP{?0|@B6%>`%ZNoS-s-O?`}Bx2NLs-q^e45_B1(; zl9w+6qt&+AUH^6CN9#W?^Ye$$^C9dy!hVXj_MQ2dq?LIEB?z35sR-_=H4FG()8M^*sXGLJq4-`m@XL zNHXrnIFRwHl@>LN_9s#bXPJ7T2gw|$;gfL{#}HJuQp!7m3ywr78d3hUE+0|LD8(s$ zHXDg!W2O!Huq~#HSuRg~Gibw{XK!t4nCLh~3r2Q?3=-PhvFr1Df@r`D~yHvGVYU4<@(Njg8WqdJxeEu=t)Xw=wx z^yh;-lbZq|hYvaVuxaS<$H#*@qc>cvm(Ter9h{$8nw?d-{E>no4BZ%?W2nn?28;qN z=(?ilD+|`(vW?j5sKiPTv2kMuKv#6uTVDyk=!~h)mLu=b@D0<@IEVFZJ{ol_C z)cAlDZm8|?Y~}A~?Ag~y`jUDsx<3wp8{7Y|_`2HK*LmaL#eg=wM)>NNe!>>tVdgcr zqw#vWmKz2CT-HsGsaAMZCiU1M+XnXcIwW;|1*Sk`S+#sr%QGi#A)}!8p$Fjf%jjiP z!K+~RlGnWN-!4?<{OaA-=jmi72^kxAGbTnI`d2r3dYm{PT4t^8yiY76!7UPe>j$Z} z)oo;i;RV&WU5Y*PI^AM#eE5o&XO-VqlY|+N8&rL2*5RW%M96m z$2LAK$f+(W8ZM(qV=PBk=I_q~l?^aqqlfNX@n%bA?X5%H`~UIdqKqe=jzp~&qn*7 zQnr&vscG#mDJNh^-eU0Q)5$_=O#y09W9#ws8)7jtbuy96UAJEJ@3z^H&my|7Y~Ex( zXNO&_?YSdR{Uh?_to=cjE2g3p)RPe{hs8{^_pG8Y_=ox_?0Jdlumh8S?#&itkYz`W zVI4a5lC}O;a{DIdfvhN&7u3vngReA1_@=j`N>E1%^QELzX2)K9)-AQ{@u<7ED8XNL zKkH#jK_H6e1^2^M#pb%}pcni6FPF2JMnuzYos_KjWZAUVdPtfZ5^tXNlv}v+=wJWQ z0;pZLywZ(>Ewn_T$0zj4p5t<)w9w5%Dh!|f;=(-=<{_dc`rX6s(3(ed7rQT#1Kt!# zM^1WLT6Yr26TYf69|`L4s*zuMmLcCU*Kf+v?U{ta+Z|V7}!n>y|qje{m~R4 zVJ2?lg(ih--Tyw3jGhc2zQwBaFY|VY%m(nrUU3=%2 z%F=n7O*m(dw)YtyQ1StR*|V zP+-+jKpjabBZe2FVmsDFZ11|Y{xJsmQq;9cLKR5a7s@e@sg0Y>G<#u$R)ZD znfYquuAP67zB;e}a`0e!uop2?D;6)(8QbMkgfCWpg8hq;E`@jOI#D1@U0mqogDQ%7 zuL8&2KV%n-&t=AHk69Zx9?V*A+kMv-=cLl$#TV5Nu4MAO|3srfM3U7>EikF{>-B7J zW^RZeYA!QfMGu+ONLRiy=7$DWr&)O~szg2}bjB`_bV!IqC zhn>-mSKwa3Ru&|*un9ucteno8u7Oiob8E}7oxg!np#`4qmA3eCzqKV}r0F|?cMAF( z?}AC+e%xI6Bj3ke&9yvgY2RbUqQWeeaj*GaU+=FlQy4k}9tC^awT4ozqHsIJkOwwq z?ax}%EX5DPLX7%VjKG0iZ!9PjcL-u3Ib;?Itip#v_SY@Lzu=TjwG|D_;sR=ioi|$p zzQqj|)RqV*u%u{biYe)x?~ZtlfQRZ>kTqoLj9Y6jsB$vRQ0ai?`2kzTTP?4Vg$L>lg!kszy61=d0YpbA|RZeY03 zTWTL&$`{~T!{SF}&d?a2abhhV!q+BFov`Xs2VpP@^zk4!R~pvSpsO*h*#J~0i&IV-wO8jWE26k7|TJ>q&K+`7lkmM&$VRz z(-X7r(O4-*!KhmT=e}8-F(jE9?3WGjj|dPM1X9Gd7E&KBXM)ku+WImb7FkDls(^Hj z)mQ;#-|pb#AjARa(JlkQ8Y7t)AyODi^`TX7wpeIEWVunLk58q#i>QCb^rsbn5-sfTvF*zOh6Bd$`EeRVyn?2Uuc+z5djhaT&*?NS#FKWIqwn(KYRm=l4^tt26A+#T+RPOvYfEagjzj z)oep2ARwut=O-;agW9)mAKN6WhT0kMb5z#!NkdBg#~buzoZ?(=JyKG4EKb;BR6#N- zM%hozP$Uqj9zTBEp2(m{k}4XO06fD^OVqGVwX5*roEGVD&QJ8{ll)|aX$Zk90DpD< zo&7M0*Hna;!~@tPAC*8IFa-|(>rJyaZryrLX`>h)EQEhTMgadGrf=Fw0-cSh7zv^} zu`+#J4s=*08^d;PH@EgR_6z&J*fv%h3G409b>k;YNT$*XE^?*4NPCK+9OA?QjKJ)V zjwv?JbjS|rgP zC@g&YXPo>uyF1?M|Uy+A6q=*lYtKPTo&9- zNYE;�FgozxUOq!qrY81uyaAt};pn91AAv>8jE8zKNPJ>%1Ms@9J48qDaSZ^;Nmu zEsQ>k$an`8Z#FX%e!WXlIjO^otfm04X5GPL-xH2-Ov(UtcOH&ABNerhC#<`DLwBiZ zASNPV*~y@l!cI%fJMbRdx9@P8P;R|nVHm$vn4L^ZIpX2b2WHIQ_ZOsUa<&OWB{QS; z;r9lR82_dF0&M8{dxu);wKTdxs+t`bm*&PVgOWb%sMi0;g9Ir-Y&FNECO~v3^ljH;` zD1@qV(G+=$WkxfJ?7;=6vuZAe>?Cc4fg~f~Lz5(#q!bzEdl(RT#DIzGPE|QNvRH6E z?M-IiSL8!Q;K(!cn$3KbpzP%-=Yr4GpxusD={bZ|r`rsOA8P9X;jERc7);3(<|;#& zeu7DfCXqO93eLG>R2(X=HPpso8zwb*!(89&wM3gIWC#cO7u{J47p7XL9&5D#V9Zqh z^5WzQ(LIcW7JeYux|jo5zzOaE6+RFHl?a$!2C zt)LcSTL3Xa&PBaX@Jd@K5}l%M!!$XDB#R}-uUW|nq?J1D0H`nbiHVHj9d;ykJbC)G zf*3#AG;>DXMsss{)U^sgO~*_N5G(uNTP`v801j671`t?J`fj$dnXt#jC7JS5DaG5% zs~_!hmq_|DqX~!Mr^0*hNUcCWw$l41UgvtP=ZSfm*tX9}ZC$ZvPofZT3LH1=Hap~` za{;}$N&3ID?jE5$_Y}KRm_`e~N)s3~F$k%Mps20f``9OJ$iMg38K`>r&>=wwTvL!` z3~I1$B`Q!aGTOzdHh_S-$M*GBM_Hz+S4)eOX$kJDa3xw9?!b zqUlBpi%_W*%W7|3yf$Iy1l?;#h5NK_CBMz8L1hK@8PV3~V-&3*utC8^VA%os?}_Oy+i z;E;ZD(s|J&NoQ2L?cM`prcTwibPHQ|Cni)UE0^yMDV%eD&vy?2h=u*+pwm;58d>)N zN3=Vs5Wf?Ryyd8}%Kwp*SQV;0O0d46Kj9k$=b~*3dm1LS!p1h3tM;WQ|1;}5Gj$+G zbePHDwNGX3x+|iGt?9b(UNKo};~7*{RV5-y@S5Odnkf&FAx@R&>71Xv zTi?;iDO4d0rBH-!B}%n$75PDm=1YC&N4DZ48*EuQ>FV$^04$4b+^}K^O?C^GlFXh#xbfOu+$?;ckMzck9(_wY}eNkG`XJj3455 z5@!XB_rwyzxMiht;@&Y)lQ%`%F&rL&d6bY^O`b91r0WU{OBO2pp&SyT#GUIH&sLTp z5&*p!D-EV`tgIgT)9Bo(i@YP$K3{_Q^ zxTQ|Q7n7@j3p85Mw}1bUv)esk_xCRrJ(Jqe=w~93?zoy_nvMZXTd7x(AOgO#jb$Zt6ddz_UsLgGg@)}CO2Je zO>e#M{fC-sN-A{wk-Hg*-O6>xzffujvLI!;dS{;(pDoHqlVqc|S+!1Ge(2D-mR;(5 z*_26<^9hj&X@Q^R9NY$A0;HKPaiY4IBO`@0IQYF7M3gL$LDFl5I2Q!=CTGBv7B}48 z__(9+eh;&`0%}qW^81LKreN)t8LBZaE>}DzL)}i^5E@6B{uq0SkdxnT-n{u3ETt=R zysxi`R*!apVG+I<6iSPeR|`~@3Q~mB2H7)e6GcH}xf*V-{bmzK%EAE4)z?-I6Z7`i zH@cFAR8cU`^Ps*twxV)e(H{zGFHG%gAsC@>L|7sh9x!H=fQE%{mF{l3QAEE8?HAfr zFOd>sFcP|{fl4pL#6PYFq*fauX8Nam@b@CoR{r^HFZM<3RDVT+hH8(ZVaDt#$aV^P;{G)UPetv!*D!o(80C0=$Oc zpBCrk?*2;Q-gL7$ZwUO`KR#$jn~8zm9tfw({5$=i7Xu9Cw~NaXoS2%wZkTQt$fTkj z`yp6^g)?Gs$|j}72y79_8sB&I;p{Jh$a8FIiVc$zpP%9|`2BO=;`X9rC&c%RBD#jb zVrDlgFlLsR{09I2$d_P2YIom4egtatH!N@AX^cSw z?&~&C{Ud<4r!7fLtjg7~sMv(^n+G=TT)1@Ix~<$rl;8es1%q||?LO!#SoA?bwzxOl z+EyrJ%X-$WmAQ)Ii#%EF+AJ|55VM*tn3N}_%0FKsoWx;d=ySN)a~SRN@CCw7s;HsS zO^h?dMrCN00}?W5xByF4*Jek()cL5+L|^>g`0T9#^gW323OIGR)HH%3?IWKqn(~?Z z(tpkf9c}I2Qc|ai3zcL_+3#Y|lR!*U)YRS;4`E?91C&wLcVR zmVyC(z-`0s_h>@eAy7|}-Ec%cY6h!b0y_>)9?B>ma5i5!#+15SVhYV`X z)$CnudVIJ7F>bP(|7M5UX{69z#I?l@rdVjrRP=ct7=-cj*x7kK z+7|YYa7u#HDtbWztFhGAmM|fVs5HdVp%v;XG6r(3X(>f%>CcadGcdglTq(1biu<~X zE4lMq?;Efo`raYt@;Th^&~&9YtcZ{%gxYTtdb#)o<=!DyXe%CQ_B&lNuo1rzl|KZy zU=s$E&o@2GB3qoCi{P((NZ}|&yG7&`Atq%iH|scZLN*t?Zg6@bo0pf9Gp2g8WkCku ziMHayKF4fr&BckE8yIBw_%i2!c6j@)wGR5Lb8~wh?lJgEL&YwdE6cJ@#+*@_s;q2a zY5!DERbOK8D+)3$BqLf9-78RfPd7=0;D&~tF4~GW>jEp+TyEOBqUJ-M~Q59>lvVxmJsk7szz%fh4GhODkiLYq8 zENcjwX?v-tF5r%uH7A_V%w$!BuCO$B_yd=!m%5%XyR(_PJKR@p&>#iLwdPms0t0$b zO<(78+Q2hMA89M2GE>viS=s*1Ev(ehu-vh(1ygIf%-x!1Se2Ger{ApfVNH9VewePM zxj-+;*&al;E4WjDr5xe`k!r+=v6lR-5VA&mkBN>J zU*0%<^yr101?CH$f&MHh3FsGichx5I7jF(zOB6vlov^%%OS~8)ZrEUoSdc|@Y2`(A zPa7sPiEhq_rf191R|~Nxc>C80@-F#io3)|nrfio#6pOYpG%%9OI1FuD{x(~jvy{eT z;+8C4{E!%6Up_8MIGzEZ>uqa1-`cz|EQmcX#04t%@bb}>Q}3J=ZKn-CRKkh+m!m1DxISW@jryBf~>WW&T+@7;ypcyW!P!-k~INsOhwm zj8KsUzo$&Kus^FK;CYJY=%CE#{k_#KLr#H(E-5`B6IiR08~Dy%bm;%9q%@b=hur6# zOAAhIG#k2`+~w{HKBmI;&oeR^#@+TiV&|?lRJ2=e5)w_0=4&e316roak#y)@JT6ZY z0=4p9-Mq^0$CaN+6igg;*&}kaG&xB2%*g|>SR$iVNcr{Zgl{e8Z()-c#ErQKY$+_! zg#wXgi!HGq?y|vq4;*+kClcsIrNQq^$5+T+CfTctL36@b^y)n%WgZmJi(|nsPLL>pSZK* z=hV{Dw7~Xc^fqtWO9+D`ROFsjI9=VGu;|2{&?%5t;%?=nyqZ53E+{}jKOm{#=-Qt@ ze{P8ls{EMhJpcStC6!L2t?l@d*SzM-bt3tAOd)=q-qbHiraE^%;rCOdG9&#p3O zoownM?5NHZ7lKj>a4>nMoT1DkkltHO_m0XGV{_FGRS!G;wAwTeAD@5(`SlToN` z?{3%B`lqJ#*5V}fBA!fPvryKEUP{XXqgr@2kl#x}+6 zn*t41pfZ7*Q)w8X=oc0}CTaCl3?j8BwGhacmXBQW#n(*m2xMbNj7?S*9OGEYs@9MV z9;m)3V(!PUU$4b?b3-F?Vo--H`=OJgtT?jt^8(MDSz=WAX1bch<%}Cq6_v4bs}m%` zCJB0r13&h?{*R*A(8hk9I&~5N#2`7~lQ+K>^b9nx@E!&i&4I6Ljem0Ro7OKzZRmyD z*9yfCFJ5@MJOL5_*lvy+ds9mddV%>!X#@?)(;*)Wb?O%GS6_hk6 zE_-xfg&i3jZlni{o8FW(ot-Y4Gm^CW_bJ)RetyE;)>}XD*srr<476ZwmQ~Fbn1$uZ z5B7_QHKCXEm!&ET1s%PLx;;;uY*A73S59xImZ;LfX0)bXgwN~gYD3+|hj>(W9gy&D z*i+-;X&rM_M&>p~90v!MwogmoR8x_e)9;wY(<3>(k3D>e4jMoL#xdMo73QREPdB=?tS*xWc&J0eiObZzT0u)(6_IzM}+PR zAj$X|aXT#~rK1oSapXou8#y}86jj@aGnd>H>zBsLJzsiTd}zv~&^J#5NLTAt)n8l; z(=3#b`+ta0?Ifbow|W_f>eetQ7@3#z-Bqat1+sk8W9q(XakMUI0EX#@KTfRBIAhDS z8Taq9!#W`Eq`NvKY|UsCaFz!{NlVMC%<<6PEzmGe#TT1qZ2gMy5Fp?PzN0t}`Q9^I zbEQC!MgBrUCtRyQ6a%YRKBXLop77Y3)PN${YMM%<_f)JXN{Yz-(8m$oD>^ii}h8D7&#?AXno(q z?n}3E`ScS~MUM+!{TOn`KSBOIkXVw-=MlzJ{bL6#+H%@S|I0|B%ahaXV@~wzac&!BSdT&L z7o)EeOw9ftwx1sTy%ZVC1(&Dz!9J#Dmdvw9&lROu=H%?GHKCUs;+!C_a<=iMH$VJm zJe#O=#^2xUWpBe%Oz#z-&Vk6W{!mcrtpJNtJ&uu`K-(5rbvcea~1tDnCSLz-uQ15o};OE9p-CT!$KIc+1d&Z1? zjaC!hb-$C5ySty=^7P!&+>8?~r7Jl|$NL3^ERj4dx38bm?GbFH%dB-5&q*-LB|5L`BGzE5gL1iF5!3=frKj79Zag zRi^(e9VdI%hZk9B>5e({{Y~69rHh2b(UUy)Nc`tk5{?Fe>aP0xGdi=&aaCu}PlA?1 zU2AS&F(LkU8gTy$vr_!qJ|}#AY3W+P_&V-mHt=ZmxQxq@@&k9c~1|JuA~`Ep$Bbc;9+^q|bw6EgC!cmAxn@1{pOT#a&(%5`K;TIl;hfJymJlaFnr2 zg2*$FSB)~bxqwnD0qs3C0jewc!>L-n;eH=n8b8Es^FrrimBd`NiE=g7@3#L6Ig;of z7#PYgu_6=DTDmx~VSJ-%z^x zN?!82P@U9wuFF=Z1N;P;Z65eJ`SYfEfkP({VB!y7pcke*@@{oZ(g%PKOIY2Bav2#O zv4BFBQ~|615dAb53`iHv?tAa6f5tSsE-Sw++)&wmt8C8e3Yquqpe9Ix#vw}OZEWFoZf8O9_2GZke z|K8*~qnfbzSbIdGSjW&S!$ zcXC0-9X}~d%`YOAJ9(tj(K%SWXt6S~C_%@UN-T@AbI33VjEkq-sQBI+$+A%47*tpC z$eRbttFE0;?&<40UCY53)Eu8Crw=t>gxj$jvke}bj!Jdv%;1vvmKJ2>Nz>r#PafH| zYxV8s%!S6v;PFTpE{ey(0~3#hXbd?#Wm~T*PZ&X|z`o!mh*=JQ-MRKX$`GHf)jaufxB9!QN@cUFkVE#677+Y?~cS4LJ9Qap*V2-$lksH(0d z_SieFLyD<)WvMiGG5OQ%*TlM5q(QU> zujV{^<<20u)1NP_m47p<%bU`^I#--JUvcVt$HUVzp<21a+=p{x6t}k*J=)w~8;fxx zMre@RvAk@XM%BX7T|-%-dVQrura5=bW`)0xmzQE#`ks#u_6>HOJ$CG~@8)XYPOt!X zz7&eCo;3xx*F}!ZC8Qh{aqP*L`|_%1W|=)nO|6K_+j`Smd1i!@@N zY`woeY5@ITZf?7`HBt`t zmzK6LPKU%WWGf$z6}3yij|gSH*Y-peA)o4Kd;7%#E04P!PHVJ-gb?r z(|>iBkT^@@{s;t`d71Cn==QXhM_|V>dimMOYSk3|&UkU(0*5(lD_^(XA>-xZlgE26 z%l~Se4kAk8kE9C(5z zu;&TgI+bj`Vpl_4h>I=c5$RoD^ExJGn%%nx5ng7SKXw!ur&~+K`nY&}AU7M^f=d+c=gWtojjZQ8RcC<*+@!}Pkceql zVx_M~X5tr`8FJVgMnc6q>2>>q&osmtHs16oMmec+gJwJ50%z8}FnUGn^II_oM@P1s z@7X`v+qQA{OO1}}BnJ%&Ly7gm?lVjwu4m_=+kzI3xLbg%$y~Wi-l!?biMiwJfBl9C zo^T}k_x$IZ;A7&JI;h!~`!8%ti4b%1@-*hE)zRZy?lYs{70Gzg?EZO9zfW%ssI073 z`k;e8M`8wn+PRD=&#|#DD=V#u!%Igj@P37Gc*D*#D7;wi2kG&IY^iN8GX_pDd^=`h z{I5GK1VfHxY~VqKlden(|K&u~p#efGVcMLUF}Vf_(+0ZP$J9|V`c(h=@F8Z9lTfBY zuX>BD29fc$Ua8XyR9g4{^H2XlgDig~nYqe7Jx?!HC>d;CK{X4yA;ayDj__44S$AY- z*{*8Rm4=V|dwm&X$C>fMVU}$K-JLqfe>N4_Zq5Eg#x=X8{!Py#vq|`HQGnICJ8u2j zM91iooXB(yG3|Ls3N3Tr5~2b>58HlT=LTAjxoA%_VPgQl#*0rRUy4-XewCNZ$eR_O zce9?}$!OPg;)&VqovwlWk|>=OpPuz|+4Rctpz?2aeupJj+>?sm0q{TAHHq(0Sbs0E zq}DA$b|d*q|Nc8SSv>?1G}NE7mTn`Zymd8s@d8(j9Wt<&^xqnQ^7x%yfM@W5s=T7L z^F&&DE4b{Dk=}vpR_{=4t*UJQblK9S-mBKp(sp=ZQBLq;KdeQjy_n+`bW2gWb(!mm zyN^ESPm`9DIE5|TedL3es15V|J<>B@x-Gk(px$$Db1BO$o8vvQ*~81L?#EGy2P4o$ z1e`i`jic2H{Y^@0?vyz-9=ny_Ot(98zJewj%G@EcXM%!UjMROzn)i?+MCO~Gt~~cE zd*lTXE95=Ia8h5@>KoRB3*td++0lzd-xIx3_Bx|fmi*n=a(3*O(tZDoLOj6nlqH1F ze!U2@W#8mfl+z@jJd`D4)v-QJ(bW`fR1bo)7BC77{Oz8aWsH%Y;4= zxG;ov261BS+pOuEZ}0ZoczkkYrN#nOJ5c;17vpGiibYp8`40gz+j{<%nA0)T;g--n zn}#G>Q4)o%x_syl-iZ}lK!L39!u(2)|Es+>kLx*Y|NcMBV1`jMBbTyd#*%H6Jz5PT zgGAb-ZAemSrF|Wj;Tj@}vJ_?_lFHUzGGi2?NQrhTZ7S{4Lih7*uIu{#=JELbbwBR= zzq>zXNPRx<&-;Cz$9WvD*Xwv4^FMspwR}hsGCtXRr>+)sUjXZoy+=rCJ)LP{U!m~8 zW9x~V+QHAmUn)vVTtb)@fT*;dmcs6?jcNz&R3wfEfBspQBQd_e^}AK_lKYP4;5c?y zK2^OzG(B+c%%eQ-ZGMXMQ7rgP>!GioqKe8EkyWXEXz1x7MP#bf)2;nAwg-ZbpFFwK zb23KW&+QWLFFAG2kpg(l$zgZUbL*2xRjq9pru(g*ruE()&+|Xnf&`!-x0112Ex-5D z*{|DGq{RT6)8Ri(Sfj0-?lNF-YvKt?(`D@NJ_K&d4Iaiu*L4Rwee?_jgbx3_*e`+3ijR6Icf7$ zm>*%C=$5%G9XkcD>SL=ZCF+}g@zCc!j;MUob>}1L%X{kcL!E&%C|c57-j{BF`p-2h zF3izB_|ejBE}4UqHZEOS7mU-!o-(O3{xxgm_ulMYrtg|N0Hzj4FHLgBKUer!yL|Ar z4Ks4j^OqpI+nbIQxFBW~B|Aao$0b1ltq1e>A>&T6(f@ z)UzD6z1do~u7V!Ne8pN#hvA&i2M->+@jmE5k9SPYqeuF*(zp^`ONu*MTXo?o3MpjN z+aE3LrC?cvq`0KmeMHe7QNoOyFv0GYw*?XG?oZa`sbQ*~Z5MQ>L`Hfq3-N4*rMF?M zL-5twT3=sZz~P&Ax8FSK_u;33+P|qZ&H#bhpufcEg;t~o_2;&kzwQlk?k^A)D?^g=`#NnVuLg(nhIe_;0twy zT7PO2CJP_l7V+)2&)U+U)uQoc3jO&J`vq0RY0&5oNq{o6xlaD6Ur9x!cdeClbu4L*1ss-3{bo*M($FisUr{NNrF&-X znv3WKR^8vY#L?{8ZWoup=6Z)HVv*!dA)ScgUKnSQ6>cDjxc$mnVYkEC?}yH_fA+NB zlw!>QAB>!`R?%rn7^`scyGFb}hD5qj37lGg5SBW>+RCp;TtT3mro6EIrD zpYOr`FVeEB`!%Uqsg@l3asw20zjCY9kjz{vPMOPmd@W~M$?tdGBnjRk&%~jvX6YWf zAVpc0H(dbGBF_YJq^S0*QDTq}-H4kAxR%lAU<*fc9Qu^)T;nB%s89)Qi`*W{?TPvR z*pyJ`mHR)t9{A&XTD|J?4QoY0jl`A4ID<&7(frbSv|QIZ!&^TNX}3t-GnxYUdP0k+ z|FSkm3YzLWhL#-LCx|ozf z_i8j~;jF1s?=HyQekgtK@3{|eNUl<)ZfZ$yKcbf{G#H2(Hj%&R^9hC;sh3o%WZ)77 z>%9VRu*A+j8erw@qkR^dJsTpC#2G?OM=<%l~ic+O`!I2)%%I7;0 zmews;-O{HWw;PRElDhOOhzG5fuh$-*`RZLe*<<8(>j$Yfb_3tnwD{Y583<|!JT!Lm z^(jSFj6hFdqUnLLx0LKeFW`z34E!ouFuco~5%ht6it%nKn=L3nxLb26%T_{WC(an* zz1K$mcQ|qb$OMxwya@tD?;InWsVz7pJTH6{=_P?e2=&jWpu$gGV z9aryovNE-=In4kJL;A?0-P&VhQRwVi??xea3p8S@-0E$grN@#gX+6j#8x;YXIig8w zvW3$L9}D8PP{llJyuG9McWZWkqJruGc0^jvEXM1x8yr;;pe6y&Y!b;ZvNySVVzd|& zK7OPCs(<;d=g9|))o09D4ZoJxc`AD>@3Z(|+NrpDi9}+wFyfs1as8bO1^}k6P;i)x zGd7tnZ?VHwDGUtx)Z2ar+U=wUlycv2+(AtfX~Sh&d>)@#nEmQHHRn83JAN7i#f%-d z+Il7iKw(voM(<@Yv=yYp%W3;r%PoOFpmvxIU}yh6>cR!td-FA=ON+JL8(fYb_`zk; zgGY}Fz!mTaHK@97Y|KDj-S9Q?dr$EBZCQ(<4^Aj=DVT2daHZ4nx8LwT;rKQ;`*&ZR z58`CJQ^Db6XeBWuC&RNf@x9deUojMZ#)x3<_AwNjR*v<(f5n6{M`|Ni==LUDj)r+3C z|306amVV_I?kvZA8wr|&n&KF@$9ag~br8r&S#lB9DJM&_v$JD+dU*6(6h}&r8gcJUOvr<<39g<v~S+6oJC%8W!h{`-PicMleBdzik7upViBmk(V0+#0;)>0<1_z5~v>5o9JF>;JRb z$CZYL@8Yj5AF^Z1xA{-ISl9El#?E8_)((?<4`glMl!I$(TFVPQ4m|c!6 zPo4t5BvE6Ck$V5TyID#bnA7wtg_wuOclm^7>t)wnjp=-M__Gtw1xmHe@|hxxmZ!gr zZmrn?FQ8SjU18d^>O)~cZqw$AIqMH|a}&2Nl4yr(`kN!3xE0t4X+L*V5fD(B!(Ay7 zF0Ddz3n7PZYbpry0ev%63Hy~-eMwL9ox?BYYdWe7SZ(Qi{s$lX7Tu~j$#5dF3D(*D z>5&I>@Ec{+t&@jZd;2gG@tpIvgI5;ff0NXgqtxCm_yH-Y#K1E&(b%`I6T&jmSZg`T z*kyvq3Dp=PqaD?~48y#7pzHor0>MeVi zx%Fa!^l#tM!C2<-(OJtjhpUJ+ya2b#nj1MV9Fn$z9msGV>5H(+&HQ% z*ZDUy3@RVAsm%7S?HO}Pmx?Gyf>#P$zTT1To_Nc0M2MXiH(t~oCcO(YqLS=Oo`-(m zOlq5|`lc3!c9;G;ONi9O$FqcIix#ly{@~h>i1+>dffZw`*(e8kdVIbCU*_T52)x;t zxz+E~IqA#Pb%-E!W9rM7MMXt*?o2>T)!lnTNQmU-+S-WW zS>B|X^Vc49+o$$1ov!y6M3LNOsb(*21A|U3+i#|wdow(IV&unGV+7Pn6Tgj|n8M14 z_Tu~*0ecT@x4v^oj4zb2UhKTJcW5tN@_W)AZl$ikZDOm8ok|=yRg%T8Ch9v+X}qge zU@?S_80&Fl%?cB=cf0Cya_9b0X64oF^X|71QWlR+4x%QnzYqM@uZ_cgg3=y(Lob0m zwjA46ixMz~JV7eyr?f4cECzJ8QRAD{55)VwUxz$>(ZR=$ffpAe8;*R@um7Vk`B0op z&mM~9ilLP+C@KP9yTn|09f$^Te3M~_5kS>88D)$C=n%dedz&mL3aBq@<`T`pB=YT! z)n6Xno>+(H;vuib*rdST-PJTG3ItB~YJI;cCzt^n!+pQisgJ9h&#&};^dBuizYn#T zJ`T-E#R0V^6_+Gg`e{Hz3^44w)FE3js`gZUdg@ddBW1;3@ARulRR7_+M&!1E#hy(G zMWy9dDm-A3WAZ-9bc_aC?zTK}pwdA~bo<&zOO4K(Z(Lwy^u@eqo+ar&{B`?OygDwM z_|nQ!2THNzplf&d@Qv37caE8~9Nu`xkqdV4)Pu!>f?!V7E0yytNy9^MK_Ac`nTHd~)Z z$ZRJ1l_xt};uXI+F$6+m^<>|I8vANw;*sKyJC7OSq(f$1w@EX1-g{OW$2h9dH0ct_7>mBt z!f^q64~wk{t#$bIRl|cP=9)bBn2<~-$QFoUt?T{?V0<}$#N^;#bRPMnTv8c+JqO0D zIHG}bUQ(%_0@InduDDgWT_z;D#{zA}T+7n+rlue{`2EL`F((yY*JIr{Z)?R*$<%K- z8!t@WqjE?j{h*r-znWShhqvk7*$@1I;&c`xITtK5SVLJuABDCNM_&Ts>MBby2{q*? z$N%cV?8X}|E4B8lci@brN@-57_JUbnq(+}`Pg-`~olyJKtH)>R{ywuAqQ;P5re}}9 ztPZ0EeKM*!D?h5-Pe0b6;J|^SUa=wwWrcv@V8YYZWqThyckbK=x@IQa%SAaWt@bY9 z*{B7%XVnFwC^y&qgyP_?ci0i7x_>kIrej6p+!Lex9`0F==b@}5a)r&_6KD>J;KZ;rS|OZn9*s1ZjhK3OT9Qn{&ht<<)DOvWUGplfSoD!-mOu+Z2*44qHyKi>+Q6V1I( zy!!A~2_Ju}e%X7_>xQzlq-SGB3#OhWW4&i?u%cN|+mx4IIB5}9I;qBGBj$vWb(zQC z?!7|&bk)7wQ|yEV8cT?vpb7@W{z9CxB5gkoBSF}S;#2E-kb2iiq)UglSwA}Yi+G#^ zzz#2dp~M=m6wCn1{HZ}|+hiK!o;=wY>F>_cL68|qN+_E^qmr>YS+|4Hc}Zm|Jvt^? z^shYiSeKJ?19A@V0!zRr(cjgkUw!uh`ewY*cV|%Adv!P4r=OonyB+Z)pO({YG8TBE zR{i84!pezyod5`9w?94EBl5l-x{-;P9Q{f?4VLg|=7$?335~xTkYA9hIeO&Cd?6~( z@1yo}^H6HITtS9!k_hu1Nx_j);K%MW$n^8g(Wq1hD%D!E9R0_0K&=P|%{g<{!5gJF zvrK#p{#Cu;S@&-EYeVi(M0xB^pwou@i-%hKLvOa!N=n1ETXuiwZ{Lt}k~mO&IabJ+ zIZt&m4@C&9uPZ52ZmeBvYwLXZJS)*;vSP(vy=>DRrw>%tX=Hx^>X^WexM+FIUE8f1 z6})qFoj6iEN_Xc)BU8yz&8xnMY3TT|&W|2Ec#3TYjV9AB71LUjLWcFtxRd5xj~3=F zEz}4&BMo3b?v#RROGTyyjQ<{4F)^FWUZuV3BOD?g2*#ndBRmy zJk0R8^9@%JJ3%cwXkABtO1c>#He*g~G|)K4dC^nYIE{RCOtT( zW+J*O|I^(|U;ntCbQcV|A+V!4DW2H}u&fK>6c6K;XSRV9uXJ^liFWkPmZPRFWR z+~^88BoB;I7Yh>Y5?gp3W#3J^e@RS*&st=%=utSfTK*|{DS9jhnTxIocIe>ompSqF zE^u(qt_2z2{rsiZ>z`)?R*7xGr8iYvaO|EOhOquapJCt z(&sCCS@#88#vK1}vx(A@VW;ZdsuYT{#M92}jz_u!8dM1E>nL~T?AbU<po!`Kd z540>3m5GR;GrL&kVnK=-@HyQhkcbzOJ#bCCbKLiZCG2{cUza&7;LJIjW5-!R#{ck2 ztvaF_81fSg#xLC0dg4j}Z0_8?W@34BYNMj4>n^_+t-Tr=_mx!WW0&01(_=Q~ zCelL?uGJO3-4T$RO`BF_B3KeASXRjUW-pppI*C)V>tCB=MaALqjBdCBe9~BEYe7_I zW4|EIQl!AMk+Lu=_rYRk1!ZXbujUW$DW+R70pC#JTUKw0@!2ihB%%2qeKa*~+Wh&t z0>5uKNl>$+!#K-et57&14H;zK5okrRmuBH6M_X<&R>=%edEJpr`gyuGCR+F$n*Oo~ zH3qMz&)z^pHVTb}4gzKsTxUhI@q5$}itp2gyod)W^b%m{w(P?bF0mqTf72!H+DMQU z(w3>KkJNFgThNuUY-*kly~T18*=8!UnPooz{5+*YgLpu+^O`w4J7#3=xvIjaT^O2J z%P_G`H>0;9#tSsb00fk@b|x?i5w3s(IDGAt9;8u?Add=@g?-pfzL+_ChZ!B>w6lxP*fG&Rf%672!Umm5!XCM5#n*8PL*{E*%$qiSx((-F zs51?^dLl~jGjyn!D{FtLga4^`EmDtI)0F&+t%%lFfmFMzv{sA-2J<4FaIKASAAfH$ zQAS$5<7l~0`})rMXRF0B#ow1@dg9ozUqDIhsA5lpBj~jo2NpT^(wX=Pa+|$B$p2YM-s3@aDAIK9ODH0Fz2u z*qRwlE0~o;-6MS{E{-S}<`5E55ma(eu7+ zgsikX1K`|Z=6xY4#vIR2#9D)<5o>Mwj2WUYC#%Vi8tY}E(^OZpDR&XZs)+&5> zeMsq>q zAk`K-mYa%gs#o>;>&PF3rWm;J7XD*7F9gg*jDjyeF1&l7%$Ni%mb41E@UyW4F3&hH zT$@XcxkCcdU~GiYYz660iAD*5kJs+^t`G}*tVg_J~ootu=jbdmpu z_Q`_!tZ$eI?e&E7r?f$n9fsWahkRLyr1VSS2LRNTQ~MCu!WRGqP1Vu6$>@~~E&Sgn z;HSDv4Fka?I|(>ht%l`(efB_@r57zjkWg{TskjE(oKUUL@3#WNog0&GHkX!s_fqEr zdX{_LF;;j`=M#K^v$}!V@=r!7rJu(p>oQ6OK+h*9eu{}IRG+jAdV|CwRvi5V`nBuV z8{oZz1hm1eL^+>cPVW9o8oaH5_X_X=Sd2i`Bx{~i<#KFNvO^Okr4yMu7x-pipflbd z*|9QXCYb}^WNZBg#`<<3O*%D-OWX3-Da~R=_eD$64#(sDz6rF9%8v8E)>w~i9iey& zQu?=zfLekG0v)*(_&7KBIsJnoPOJ&@TSB@Vp?KfA1{zbGtk8}kkpbmjn&5Fqj#yMU z!7=?8kt+Uq5@FI{AW@2Xw4VO)Z`v-jB4S7X}Y{x;<`uu-(5gR+G85%g^Z_Ryb zoO)beo$$l<)Zw6WiN0}mA$gKj)!JG&O7fG!wjO@bb@N!*>0QJ^C#^_8ABNd|L!xCpFn)AeIDi&@u%8T7Ya#n;2@dS;ECMC*kqUJFQ zOfumii1irXG5V_k!Px7T3IU-&7#B2i@^%!b>w;)4-i}*NP&W!+ilT6ptW&GZg7IRW zbLn*s(*e2ZOI<%pw`RP9;v{!$g-XUi(v4bBbk#T#we7NmlTQpp4VUWH)X?hZZ~>%K zq}oQk!;cOe1!X;DGsRvc(NsssXjHGh8!IOl4J|Nx_UwmoUyq+ihdCutfTm^Kc*7{=0`dazYm- z8Xw$Ld+hznU$)*!6IeZ~(ZgHTe9WXtv3$xaXrVoeTR45<2EYl})Em!sjrwX}chvbJ zAcweYd2ekoW@|Ax=4Jx0*;%ExMBgHkvC{Ho%9&-g8w*o@s4q>q`E1~@qXN~M2#0vP zwcVfU&9f{r7#s0de@(8ifxBaoQfG~@$DDu2E0JtPhe|rQMCsH*h@BF-jXb(St6}<6u>9hcd z!ouFq#q*H)9ELHSo^QE7d(4#~rY-1MjbMJ&p5F#kA)-dHUveyclkGUjMS`VfUZVnK z1!A0X)~jJwP!}4gqm!2{)byuqEppO+!IhGJ)xP{amDacSSac>%lPnGSSRW%Z@}{MQ zw{MP>CPKc|S4*FR{XR#G2&1d}W|3#|snv@Q#IT1h0f5vP|K4t&R2IehuE+T`FXV$n z<{N^5&b=UD4`6Z^Sv`gx(r>4n7pIDT{W+wby|MQ0>`CsjO5czAq48NlJWr5i=1`Q!Or&f z-Cn8_JS-4diq!jhp)VopIThC;R7F7>ic_zEr*2^q=CcslTv6~QAkpTNrr+B=TU?$lCbksp{=zSsIi9(e+bzDZg(<^v$JoCMi#t9e6((vB(wI68fEh=)K3?~%8Z`SJQba{LmEN(EjRNi0u zlk=70#hCj#qKO%i^xMm$`P{G$>oUuv$)xb4VUUy60)qt=E28jlO6mx; z8)i(LN0=T9wp{!?3d=dsEhCcrS0e+FXI+5hiS@@4+^OWxs@S8#Fdx$=`q-E5oS2k z6bGcbj`J#ZBteO#dYq+|nTwTI|3*w{C3$bB+0|zqlo7_ z7Fip)d2L!xbc{%C1nMHdy*>wZdSLQJ&H}BcYVfVho!ELbp55Q40cI5d<9716if`}U zK>^ZHXk;NuAe|W#OXXe!-lHG0fjW2ISb7VapVC*4jk^Ft#ykMMy1j zH*}rGwQ#yUWu?TjYG0Ojk;IAB7Btp%=*DX`kL9;4*I>X-ES15QEYTeyt<^q&cHbhi zzwMB7`ohBPfc_m&`1K%w`Rvm;alopomYmqSHuZS5Po8>r-Ni8Twio}HMDbC?l^=d> zipoCZJvQXz5IUo+feSU6JnD?Gc8W+Rn6a!S zjusnJmXKh)&2i3Z{L-86MYH1SI1}BX?KiY{cNd9#7VOQ2Oz-LO+8QFxx=SLQ@+?1--;G>4^MDC)d#sqOlfbKG2|CA+GtU(mL$ez ze;GY68xgiRs7X{WVBoBcbE6QHL^xa`OKYge%#`=4`+6+m{w8N%TsuTDG$|`NxImUZ zG0|!0#Wxo-6iq1OqV3F)I5LGqNxi#tcad4k0A^UrcI9keAbQ4nIqxQBCHLLDj^8|t z*wR4fORvRr-$ht$ADFMuJ=871avMZ(sz>XucV$&AZl^Y$vOatL>b1t)oPP?f@rQ|P zUF_1DSV?m*INut0vv4zwEb^|{Fo+{fD@Z;TphLyghZk`r6A%!Aa{tSQ;)EE16djQF z`kuV##;&~PafJ%Zlf5DY3kW=ZKVjEPqjd8gRjQg)zw##QRPTFxGf^$F;y(CU z44n-*OP^h+%WQag^U~9dtCwQ#y@YOE8>(4Cx$lzlo^J8T-PsY=Cr~xzEZbx47H&gj z%)VX3q&-nKiDGtLm5_4C`@H*Bv^5ruEkC8=dWBOi3Xw-jLEnw~!i5B9Ri;3+kD0T|q8hpWQ;0S_5FI!lNfVu?NC18c7d!}$B z*ZWV=q`W&uT2fvWQkAPWHejeVkZUhqu^nVeVtj@Q-KsV739UlbPwbBZI4n*y`0S_= ztk~kMzke+D4+YLf`it&*P-pWy>o~UbpiIY(853gnmc>XeJ2Qh?T!r~!an9rjhOqt& zg)6XAh?vao@db61%5G=#+R{-WWvPupfE|ovFL}`kq)8q5zmuB-C7{F$7Kp#6Pow%* znnF7vjwI4?JES(fg!crfAi=`W$hDDHLa>Cfh~UVNuAmSWvNm2dl8-T)2Io0HJ1YlM zJ92?`H0>LGmYfKn$q42@=V_bzeVEJevtTM43DLP^}kVWhg?TBXio^WiU3RT(V-&Xyg0nts) zKvxX0V<*)#9)-uC!7tHc+VpVodu-V_xD0mgtUF=nMMWY2$L1P|X)OdoA?&7qO>FXq zLWyKl*P8FX)o(%B=G-uBvgFP(%J2d(qfYLxID54Y#YF!4s{paNHR~`;zLpTY@aQ-p zY@uJv77$6`kTpYJf}RLLDwmj#2U3&MvEQh;2e&HD8&b7tRj*$cAOy_ld~Q_U z&uzrrR%W#e(X5NCJKz|d;ulh6+?Mi^Zdqt~uuA4y)k;=>Kux<+A|gYN%H-rPLeVR|}_HxgIn{!O*2+1Aa(O4S zNK?w=@qLcr)L4rcMe$sv2n+P))LeVde+`7d>aOe<3aSDi#|{954J?mM^o`fd2YxW% z#-9M%!?#DTX=1<^ba0!z;G`6d*KF&RT-kFHUBT`Q8!^J7eOo%mh)(S5kK1kv+?r6? zD{8iBS%CG1WG1t>+`3GFfM%L2DB@^0rYgmQ6Wj6xRS=Uxu8`Ws(rp%l_94aa!UKBe zW2^wA2owT!7IA^dYPDm>Y$_XUO5uVYooGK-EtCY=Z7MJ2ByU&TSSazxBfE)zlsB2- zot&JT3=N;37?w)5wvkL)Y#r;g?}nTW1tsO(&6AuKPnog;{4RBYX3N_+(Vdr%i1ZdG z*TqYhaw;kY9{ye^q*SNpMT9-bHKIWxz4<*AOsxza6j(~ymFJOjz9?3fRG+|E&6ZVO zeA7Az^WV6X+%i_Im?#Q5k*=oSwZ?0zkV<i38|9C`qC*d(pwmEky}V~GKl|yZ zQF@4fcR4B(mjRKDpfc?pOU8&W0tZ^isQZ$YLnJ}C6J8ekn4ER+_hw;rg*y^ufLJ+P zqrGF19tf8KrjY_xq_3mvbMaIAMVmSg;aXsOr&Y$XzuMxAkeJ2ws}o5NMe<_zRT;rqjKdRZO@8GI zK%aJg`;Ob-+fRqse!?mIO%=e@@e$$|0&a@@YcvXpnM<&pxXbry1^1w!kX3S*cfG$` zfW?qAc5>rnWd8b{Ya1v~K)Scdjr;P;VDU12;6IXp19Qz%^=V(ABWKh1-~S7yilzbmBR0Eibxs6B|ibKY&l`!g9k|?9_kSbAvR)(~0IBE^ERm&~yN`u_XKm zGbXVgs7{N-1| zB+IvdXH<}YfWy~1-M$J0OTpcF=dyTX$q_X&3q{AOX3!Br*Gsi2eA64ZB{STH z03_dB9U`;nb~#rdxo^fHMSOqNw3WZz8I1|S3C@VYTQ89~tspo+K;jNSO?zi>1{hs? z>FS?JoJKy#%L6!a`Xf8-5}-jO3aDi6ZTySKyVsy=O`)7t!q{ zxo^$E-ic`MG7LT&n3po$NS;ZGy#^|*u5-ke02^l@{OcS$4Lg#np1O1?()t7W55ywLg@=<23KYIYCLZvk@<=UD$sc`;}5dZ zQJOyBJ4;~RpTxx9z5+@i!f?5~lqC64MG1Gmat9fKfrw4SgtRw}kV-sq1B`)$g)nco z4yPIO@zZ(jiDCebXrd8psmNfQoURcJ#uDB|j}zwvD%Y;Ad0ZXM01P^O0@#VL{<__5 zbBt_hxG*DV>65`1AWUFkde|K2C&jSo1U#+JQ#Z|2S2shAj+=%a^`z*vVoy@~7~0a9 z!3R!PHN4cca+A=W$-KDOu4v#hJsnch7A+x z4_rtWEwbWE2XJZd5C8L@<+rOFn+C7_f4;}PJ1q9f-@m@HK%xHcZ#jG`?6R?D@Tg(dC`CVrS#-~@9clR_3!`t z*ZgnXn*Y4U^4s5We*gXJ-$%sXzuEJp*!+M0`agdw{^!~K=birF1&TZI|Ff9>^lvJ@ zF8=M#SN@GQZ)Y9}nEu~|&v!Tfe^=@M?SC4%PfO*!*Ee}0WJuPm()liK#SeS`7YPph Axc~qF literal 0 HcmV?d00001 From 6dcdd6536456158667747f724d6bd3a2ceaa8d88 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 27 May 2026 08:46:23 +0200 Subject: [PATCH 152/289] ci : only run docker jobs when pushed to master [no ci] (#3828) --- .github/workflows/docker.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 6c0de0ece70..c5162dc8251 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,7 +1,6 @@ name: Publish Docker image on: - pull_request: push: branches: - master @@ -9,7 +8,6 @@ on: jobs: push_to_registry: name: Push Docker image to Docker Hub - if: github.event.pull_request.draft == false runs-on: ubuntu-22.04 env: From f6e617bab7843d86d94237f9791ee41d524333f9 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 07:21:25 +0200 Subject: [PATCH 153/289] ci : set GGML_NATIVE=OFF for bindings-java (#3830) * ci : set GGML_NATIVE=OFF for bindings-java This commit attempts to address an issue with the bindings-java job which is currently failing. I've not been able to reproduce this locally my windows machine and I suspect that what might be happning is that windows job compiles on a runner where it has different CPU features, for example AVX512 and when this dll is used on a different runner that does not have that feature it will crash. Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/26496174929/job/78059073255?pr=3829 * ci : also disable BMI2 --- .github/workflows/build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7ace04e1207..aaaa8fe5826 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -640,6 +640,8 @@ jobs: -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DBUILD_SHARED_LIBS=ON -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + -DGGML_BMI2=OFF - name: Build run: | From 9186e2453bdd051854b17cfb0d068f629663e114 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 12:09:13 +0200 Subject: [PATCH 154/289] ci : renable arm64 docker builds (#3832) This commit re-enables the arm64 docker images builds which were removed in Commit 9366544991bfee59c927e7c23b1861c6c762e708 ("ci : fix arm builds"). It also uses ubuntu-24.04-arm as the runner which enables us to avoid QEMU. Resolves: https://github.com/ggml-org/whisper.cpp/issues/2859 --- .github/workflows/docker.yml | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c5162dc8251..9e07f7b2292 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -9,28 +9,25 @@ jobs: push_to_registry: name: Push Docker image to Docker Hub - runs-on: ubuntu-22.04 + runs-on: ${{ matrix.config.runs_on }} env: COMMIT_SHA: ${{ github.sha }} strategy: fail-fast: false matrix: config: - - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" } - - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" } - - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" } - - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" } - - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64" } + - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-arm64", dockerfile: ".devops/main.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } + - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-vulkan-arm64", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } steps: - name: Check out the repo uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - with: - image: tonistiigi/binfmt:qemu-v7.0.0-28 - - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 From f41562bdd6f3ed19ef352e93f54bdf829771e59d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 14:41:48 +0200 Subject: [PATCH 155/289] ci : add on push/pull_request paths ruby job (#3833) * ci : add on push/pull_request paths ruby job This commit adds paths to bindings-ruby to only build if changes where made to bindings/ruby or to include/whisper.h. * ci : add additional paths [no ci] --- .github/workflows/bindings-ruby.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index c3f158e26e4..0c31701a2a3 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -4,8 +4,19 @@ on: push: branches: - master + paths: + - bindings/ruby/** + - include/whisper.h + - examples/common-whisper.h + - ggml/include/ggml.h + pull_request: types: [opened, synchronize, reopened] + paths: + - bindings/ruby/** + - include/whisper.h + - examples/common-whisper.h + - ggml/include/ggml.h jobs: ubuntu-22: From e47a3eeb04176d33630a0a3042caf3b64dc644ae Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 14:53:34 +0200 Subject: [PATCH 156/289] ci : fix include paths for bindings-go job [no ci] (#3835) --- .github/workflows/bindings-go.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index 83473e4636a..44381a4b411 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -3,11 +3,11 @@ on: push: paths: - bindings/go/** - - whisper.h + - include/whisper.h pull_request: paths: - bindings/go/** - - whisper.h + - include/whisper.h jobs: ubuntu-22: From c932729a304f7d9eb5354afa38624cfa86a780cf Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 18:06:04 +0200 Subject: [PATCH 157/289] ci : add ignore for bindings/{ruby, go} in build.yml [no ci] (#3837) This commit adds an ignore for bindings-ruby and bindings-go in build.yml as these are handled by separate .yml file (separate jobs) and don't need to trigger a full CI build. --- .github/workflows/build.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aaaa8fe5826..e855ef7cf87 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,6 +28,9 @@ on: pull_request: types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml workflow_dispatch: inputs: create_release: From 205ee5a1898d7a52167a6064b166bd890b06ac6e Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 25 May 2026 21:12:10 +0800 Subject: [PATCH 158/289] CUDA: add fast walsh-hadamard transform (llama/23615) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: add fast walsh-hadamard transform * review: add unrolls + change size_t -> int * warp size 64 --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fwht.cu | 108 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/fwht.cuh | 3 + ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++ 3 files changed, 119 insertions(+) create mode 100644 ggml/src/ggml-cuda/fwht.cu create mode 100644 ggml/src/ggml-cuda/fwht.cuh diff --git a/ggml/src/ggml-cuda/fwht.cu b/ggml/src/ggml-cuda/fwht.cu new file mode 100644 index 00000000000..74e94d8442b --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cu @@ -0,0 +1,108 @@ +#include "common.cuh" +#include "fwht.cuh" + +template +__launch_bounds__(4*ggml_cuda_get_physical_warp_size(), 1) +__global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, const float scale) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + const int64_t r = (int64_t) blockIdx.x * blockDim.y + threadIdx.y; + + if (r >= n_rows) { + return; + } + + src += r * N; + dst += r * N; + + static constexpr int el_w = N / warp_size; + float reg[el_w]; + const int lane = threadIdx.x; + +#pragma unroll + for (int i = 0; i < el_w; ++i) { + reg[i] = src[i * warp_size + lane] * scale; + } + +#pragma unroll + for (int h = 1; h < warp_size; h *= 2) { +#pragma unroll + for (int j = 0; j < el_w; j++) { + const float val = reg[j]; + const float val2 = __shfl_xor_sync(0xFFFFFFFF, val, h, warp_size); + + reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + } + } + +#pragma unroll + for (int h = warp_size; h < N; h *= 2) { + const int step = h / warp_size; +#pragma unroll + for (int j = 0; j < el_w; j += 2 * step) { +#pragma unroll + for (int k = 0; k < step; k++) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + +#pragma unroll + for (int i = 0; i < el_w; ++i) { + dst[i * warp_size + lane] = reg[i]; + } +} + +void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src, dst)); + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + const int n = src->ne[0]; + const int64_t rows = ggml_nrows(src); + + const float * src_d = (const float *) src->data; + float * dst_d = (float *) dst->data; + + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + GGML_ASSERT(n % warp_size == 0); + const int rows_per_block = 4; + + const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; + + cudaStream_t stream = ctx.stream(); + dim3 grid_dims(num_blocks, 1, 1); + dim3 block_dims(warp_size, rows_per_block, 1); + const ggml_cuda_kernel_launch_params launch_params = + ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); + + const float scale = 1 / sqrtf(n); + + switch (n) { + case 64: + { + ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); + break; + } + case 128: + { + ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); + break; + } + case 256: + { + ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); + break; + } + case 512: + { + ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); + break; + } + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/fwht.cuh b/ggml/src/ggml-cuda/fwht.cuh new file mode 100644 index 00000000000..fa4c30477a7 --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e25be3592fd..1bb09ac80ee 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -24,6 +24,7 @@ #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" +#include "ggml-cuda/fwht.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmf.cuh" @@ -2594,6 +2595,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; + const int32_t hint = ggml_get_op_params_i32(dst, 1); + if (hint == GGML_HINT_SRC0_IS_HADAMARD) { + GGML_ASSERT(!split); + ggml_cuda_op_fwht(ctx, src1, dst); + return; + } + if (!split && use_mul_mat_vec_f) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) From 1c477d4056c8c424d45b8ddce1814598600c0d79 Mon Sep 17 00:00:00 2001 From: forforever73 <63285796+forforever73@users.noreply.github.com> Date: Tue, 26 May 2026 02:05:16 +0800 Subject: [PATCH 159/289] metal : add apple device id (llama/23566) Co-authored-by: lvyichen --- ggml/src/ggml-metal/ggml-metal-device.h | 26 ++++++++++++++ ggml/src/ggml-metal/ggml-metal-device.m | 46 +++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 1f212a92f98..4a3ebb5569d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -215,6 +215,30 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); // device // +enum ggml_metal_device_id { + GGML_METAL_DEVICE_GENERIC = 0, + + GGML_METAL_DEVICE_M1, + GGML_METAL_DEVICE_M1_PRO, + GGML_METAL_DEVICE_M1_MAX, + GGML_METAL_DEVICE_M1_ULTRA, + GGML_METAL_DEVICE_M2, + GGML_METAL_DEVICE_M2_PRO, + GGML_METAL_DEVICE_M2_MAX, + GGML_METAL_DEVICE_M2_ULTRA, + GGML_METAL_DEVICE_M3, + GGML_METAL_DEVICE_M3_PRO, + GGML_METAL_DEVICE_M3_MAX, + GGML_METAL_DEVICE_M3_ULTRA, + GGML_METAL_DEVICE_M4, + GGML_METAL_DEVICE_M4_PRO, + GGML_METAL_DEVICE_M4_MAX, + GGML_METAL_DEVICE_M5, + GGML_METAL_DEVICE_M5_PRO, + GGML_METAL_DEVICE_M5_MAX, + GGML_METAL_DEVICE_M5_ULTRA, +}; + struct ggml_metal_device_props { int device; char name[128]; @@ -234,6 +258,8 @@ struct ggml_metal_device_props { bool supports_gpu_family_apple7; + enum ggml_metal_device_id device_id; + int op_offload_min_batch_size; }; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 780dfe81bb3..885344ec670 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -628,6 +628,50 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { free(rsets); } +static enum ggml_metal_device_id ggml_metal_device_id_parse(const char * name) { + if (!name) { + return GGML_METAL_DEVICE_GENERIC; + } + + static const char prefix[] = "Apple "; + if (strncmp(name, prefix, sizeof(prefix) - 1) != 0) { + return GGML_METAL_DEVICE_GENERIC; + } + const char * suffix = name + sizeof(prefix) - 1; + + static const struct { + const char * name; + enum ggml_metal_device_id id; + } table[] = { + {"M1", GGML_METAL_DEVICE_M1}, + {"M1 Pro", GGML_METAL_DEVICE_M1_PRO}, + {"M1 Max", GGML_METAL_DEVICE_M1_MAX}, + {"M1 Ultra", GGML_METAL_DEVICE_M1_ULTRA}, + {"M2", GGML_METAL_DEVICE_M2}, + {"M2 Pro", GGML_METAL_DEVICE_M2_PRO}, + {"M2 Max", GGML_METAL_DEVICE_M2_MAX}, + {"M2 Ultra", GGML_METAL_DEVICE_M2_ULTRA}, + {"M3", GGML_METAL_DEVICE_M3}, + {"M3 Pro", GGML_METAL_DEVICE_M3_PRO}, + {"M3 Max", GGML_METAL_DEVICE_M3_MAX}, + {"M3 Ultra", GGML_METAL_DEVICE_M3_ULTRA}, + {"M4", GGML_METAL_DEVICE_M4}, + {"M4 Pro", GGML_METAL_DEVICE_M4_PRO}, + {"M4 Max", GGML_METAL_DEVICE_M4_MAX}, + {"M5", GGML_METAL_DEVICE_M5}, + {"M5 Pro", GGML_METAL_DEVICE_M5_PRO}, + {"M5 Max", GGML_METAL_DEVICE_M5_MAX}, + {"M5 Ultra", GGML_METAL_DEVICE_M5_ULTRA}, + }; + + for (size_t i = 0; i < sizeof(table)/sizeof(table[0]); ++i) { + if (strcmp(suffix, table[i].name) == 0) { + return table[i].id; + } + } + return GGML_METAL_DEVICE_GENERIC; +} + ggml_metal_device_t ggml_metal_device_init(int device) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); @@ -795,6 +839,8 @@ ggml_metal_device_t ggml_metal_device_init(int device) { dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.device_id = ggml_metal_device_id_parse([[dev->mtl_device name] UTF8String]); + dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; From 2307712d32a17becc38f6efb8154be5196aa8f87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 26 May 2026 05:05:51 +0200 Subject: [PATCH 160/289] CUDA: missing PDL sync for FWHT, better fallback (llama/23690) --- ggml/src/ggml-cuda/fwht.cu | 35 +++++++++++++-------------------- ggml/src/ggml-cuda/fwht.cuh | 3 ++- ggml/src/ggml-cuda/ggml-cuda.cu | 4 +--- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/fwht.cu b/ggml/src/ggml-cuda/fwht.cu index 74e94d8442b..184dc254c72 100644 --- a/ggml/src/ggml-cuda/fwht.cu +++ b/ggml/src/ggml-cuda/fwht.cu @@ -19,6 +19,7 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, float reg[el_w]; const int lane = threadIdx.x; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < el_w; ++i) { reg[i] = src[i * warp_size + lane] * scale; @@ -57,10 +58,11 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, } } -void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_shape(src, dst)); - GGML_ASSERT(ggml_is_contiguous(src)); - GGML_ASSERT(ggml_is_contiguous(dst)); + if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) { + return false; + } const int n = src->ne[0]; const int64_t rows = ggml_nrows(src); @@ -68,7 +70,6 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, float * dst_d = (float *) dst->data; const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; - GGML_ASSERT(n % warp_size == 0); const int rows_per_block = 4; const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; @@ -83,26 +84,18 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, switch (n) { case 64: - { - ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); + return true; case 128: - { - ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); + return true; case 256: - { - ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); + return true; case 512: - { - ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); + return true; default: - GGML_ABORT("fatal error"); + return false; } } diff --git a/ggml/src/ggml-cuda/fwht.cuh b/ggml/src/ggml-cuda/fwht.cuh index fa4c30477a7..cf3df94cafa 100644 --- a/ggml/src/ggml-cuda/fwht.cuh +++ b/ggml/src/ggml-cuda/fwht.cuh @@ -1,3 +1,4 @@ #include "common.cuh" -void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); +// Returns whether the Fast Walsh-Hadamard transform could be used. +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1bb09ac80ee..23d1c069248 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2596,9 +2596,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; const int32_t hint = ggml_get_op_params_i32(dst, 1); - if (hint == GGML_HINT_SRC0_IS_HADAMARD) { - GGML_ASSERT(!split); - ggml_cuda_op_fwht(ctx, src1, dst); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) { return; } From bc77933c2de8c5d104bfc234e3358217265fe147 Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Mon, 25 May 2026 20:32:49 -0700 Subject: [PATCH 161/289] Check batch_compute_passes before sending passes when not doing GPU profiling (llama/23457) * Only run webgpu CI on my fork * Add webgpu only workflow * refactor batch_compute_passes to a per-thread variable, and submit individual passes when it is set to false and no GPU profiling is enabled * restore build.yml --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 35 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 921c12b41ac..1561a4e30c6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -259,6 +259,7 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_host_error_buf; wgpu::CommandEncoder active_command_encoder; wgpu::ComputePassEncoder active_compute_pass; + bool batch_compute_passes = true; size_t memset_bytes_per_thread; @@ -590,9 +591,18 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & } #else for (size_t i = 0; i < dispatches.size(); i++) { - ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); - ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); - ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + if (ctx->batch_compute_passes) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, + 1); + } else { + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(); + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + } } #endif @@ -1956,10 +1966,10 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector reduce_entries; if (use_vec_reduce) { const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; - const uint32_t reduce_wg_size = - std::max(reduce_sg_size, (uint32_t) std::min( - (uint64_t) nwg * reduce_sg_size, - ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const uint32_t reduce_wg_size = std::max( + reduce_sg_size, + (uint32_t) std::min((uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -3110,18 +3120,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str uint32_t num_batched_kernels = 0; uint32_t num_inflight_batches = 0; bool contains_set_rows = false; - bool batch_compute_passes = true; int num_encoded_ops = 1; int node_idx = 0; #ifdef GGML_WEBGPU_GPU_PROFILE ctx->profile_timestamp_query_count = 0; - batch_compute_passes = false; std::vector profile_pipeline_names; #endif ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); - if (batch_compute_passes) { + if (ctx->batch_compute_passes) { ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } @@ -3148,7 +3156,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str // reset state for next batch ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); - if (batch_compute_passes) { + if (ctx->batch_compute_passes) { ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } ctx->param_arena.reset(); @@ -3548,8 +3556,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const uint32_t kv_tile = decisions.kv_tile; const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { nwg <<= 1; } @@ -3839,6 +3847,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); #ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_ctx->batch_compute_passes = false; ggml_webgpu_create_buffer( webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); From 00a5110b1945a6144b7f5f766da98d03c84bcec4 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 26 May 2026 12:42:49 +0900 Subject: [PATCH 162/289] ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K and clean up legacy MUL_MAT pipeline (llama/23594) * ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K * Fix to editorconfig checking pass * Remove mul-mat-legacy pipeline * Fix to use vendor name as is and add dot_product/vendor to shader_lib_ctx --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 231 +++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 165 ++-- .../wgsl-shaders/common_decls.tmpl | 3 +- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 747 ------------------ .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 22 +- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 1 - .../wgsl-shaders/mul_mat_vec_q_acc.tmpl | 303 +++++++ .../ggml-webgpu/wgsl-shaders/quantize_q8.wgsl | 173 ++++ 8 files changed, 714 insertions(+), 931 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 4c4eda1cbe5..60e98a60741 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -52,7 +52,7 @@ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 -// default size for legacy matrix multiplication +// default size for reg-tile matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 // Same hash combine function as in boost @@ -93,6 +93,8 @@ struct ggml_webgpu_shader_lib_context { uint32_t sg_mat_k = 0; uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; + std::string vendor; }; struct webgpu_pipeline { @@ -850,31 +852,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( /** Matrix Multiplication **/ -struct ggml_webgpu_legacy_mul_mat_pipeline_key { - ggml_type src0_type; - ggml_type src1_type; - - bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type; - } -}; - -struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { - size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.src0_type); - ggml_webgpu_hash_combine(seed, key.src1_type); - return seed; - } -}; - struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_mmvq == other.use_mmvq; } }; @@ -884,6 +870,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } }; @@ -894,6 +881,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t vec_size; }; +struct ggml_webgpu_quantize_q8_pipeline_key { + ggml_type src0_type; + + bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; } +}; + +struct ggml_webgpu_quantize_q8_pipeline_key_hash { + size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + return seed; + } +}; + struct ggml_webgpu_mul_mat_pipeline_key { ggml_type src0_type; ggml_type src1_type; @@ -1051,6 +1052,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash { } }; +/** MMVQ **/ + +inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, + const ggml_tensor * src1, + bool supports_dot_product, + const std::string & vendor) { + if (src1->ne[1] == 1) { + bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; + if (supports_dp4a && supports_dot_product) { + switch (src1->type) { + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + return src0->ne[0] % 4 == 0; + default: + break; + } + break; + default: + break; + } + } + } + return false; +} + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -1099,14 +1130,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_flash_attn_blk_pipeline_key_hash> flash_attn_blk_pipelines; - std::unordered_map - mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) std::unordered_map mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map + quantize_q8_pipelines; std::unordered_map mul_mat_id_gather_pipelines; // key is fixed std::unordered_map mul_mat_id_pipelines; // src0_type/src1_type @@ -1631,7 +1660,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1744,6 +1773,44 @@ class ggml_webgpu_shader_lib { return pad_pipelines[key]; } + webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_quantize_q8_pipeline_key key = {}; + key.src0_type = context.src0->type; + + auto it = quantize_q8_pipelines.find(key); + if (it != quantize_q8_pipelines.end()) { + return it->second; + } + const char * shader_src = wgsl_quantize_q8; + std::vector defines; + std::string variant = "quantize_q8"; + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + + defines.push_back("SRC1_INNER_TYPE=f32"); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("Q8_1_T"); + + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + + auto processed = preprocessor.preprocess(shader_src, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + quantize_q8_pipelines[key] = pipeline; + return quantize_q8_pipelines[key]; + } + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; @@ -1752,6 +1819,8 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; + key.use_mmvq = + ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1788,6 +1857,19 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); switch (context.src0->type) { + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + if (key.use_mmvq) { + defines.push_back("LEGACY_QUANTS"); + } + break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + if (key.use_mmvq) { + defines.push_back("K_QUANTS"); + } + break; case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_S: @@ -1840,6 +1922,11 @@ class ggml_webgpu_shader_lib { outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } + if (key.use_mmvq) { + defines.push_back("MMVQ"); + defines.push_back("Q8_1_T"); + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); @@ -2018,100 +2105,6 @@ class ggml_webgpu_shader_lib { return mul_mat_fast_pipelines[key]; } - webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - - auto it = mul_mat_legacy_pipelines.find(key); - if (it != mul_mat_legacy_pipelines.end()) { - return it->second; - } - - std::vector defines; - std::string variant = "mul_mat"; - - switch (context.src1->type) { - case GGML_TYPE_F32: - defines.push_back("SRC1_TYPE=f32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC1_TYPE=f16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); - } - - const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); - const char * src0_name = src0_traits->type_name; - - switch (context.src0->type) { - case GGML_TYPE_F32: - defines.push_back("SRC0_TYPE=f32"); - defines.push_back("FLOAT"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC0_TYPE=f16"); - defines.push_back("FLOAT"); - variant += "_f16"; - break; - default: - { - std::string type_upper = src0_name; - std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - - switch (context.src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_MXFP4: - { - // Quantized types using u32 buffers for portability. - defines.push_back("SRC0_TYPE=u32"); - defines.push_back("U32_DEQUANT_HELPERS"); - break; - } - default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } - } - - defines.push_back("BYTE_HELPERS"); - defines.push_back(type_upper + "_T"); - defines.push_back(type_upper); - defines.push_back(type_upper + "_SCALE_MIN"); - defines.push_back(type_upper + "_TABLES"); - defines.push_back(type_upper + "_GRID"); - - variant += std::string("_") + src0_name; - break; - } - } - - auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); - - auto decisions = std::make_shared(); - decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; - - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - mul_mat_legacy_pipelines[key] = pipeline; - return mul_mat_legacy_pipelines[key]; - } - webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = mul_mat_id_gather_pipelines.find(1); if (it != mul_mat_id_gather_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1561a4e30c6..f113da909ce 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -181,6 +181,7 @@ struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroups = false; bool supports_subgroup_matrix = false; + bool supports_dot_product = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -210,6 +211,8 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + std::string vendor; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -1394,6 +1397,58 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + std::vector & dispatches) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline qq8_pipeline = ctx->shader_lib->get_quantize_q8_pipeline(shader_lib_ctx); + + // quantize_q8 pipeline + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t q8_src1_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t q8_src1_binding_size = + ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); + + std::vector q8_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector q8_entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), q8_src1_align_offset, q8_src1_binding_size) + }; + + auto q8_decisions = static_cast(qq8_pipeline.context.get()); + + uint32_t q8_wg_size = q8_decisions->wg_size; + uint32_t q8_wg_x = 1; + uint32_t q8_wg_y = 1; + const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; + const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); + + dispatches.push_back({ + qq8_pipeline, std::move(q8_params), std::move(q8_entries), { q8_wg_x, q8_wg_y } + }); +} + static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -1401,47 +1456,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); - // Determine if we should use fast path - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q6_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q1_0: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_MXFP4: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } + // use MMVQ path for mat-vec + bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, + ctx->global_ctx->vendor); ggml_webgpu_shader_lib_context shader_lib_ctx = {}; @@ -1456,16 +1473,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product; + shader_lib_ctx.vendor = ctx->global_ctx->vendor; // Get or create pipeline - webgpu_pipeline pipeline; + webgpu_pipeline pipeline; + std::vector dispatches; - if (use_fast && is_vec) { + if (is_vec) { + if (use_mmvq) { + ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); + } pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); - } else if (use_fast) { - pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } else { - pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx); + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } // Build params @@ -1489,25 +1510,31 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Build bind group entries - std::vector entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), - }; + std::vector entries = {}; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + if (use_mmvq) { + auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1]; + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset, + mmvq_qq8_entry.size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); // Calculate workgroup dimensions uint32_t wg_x = 1; uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (use_fast && is_vec) { + if (is_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - } else if (use_fast) { + } else { auto * decisions = static_cast(pipeline.context.get()); // Fast-path tiled/subgroup calculations @@ -1528,15 +1555,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, } uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - - } else { // legacy - auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_size = decisions->wg_size; - uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); + dispatches.push_back({ + pipeline, std::move(params), std::move(entries), { wg_x, wg_y } + }); + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, @@ -3590,6 +3615,22 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_MUL_MAT: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + bool use_mmvq = + ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, + ctx->webgpu_global_ctx->vendor); + if (use_mmvq) { + const size_t q8_src1_size = + src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + res = ROUNDUP_POW2(res + q8_src1_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; case GGML_OP_MUL_MAT_ID: { const ggml_tensor * src0 = tensor->src[0]; @@ -3715,12 +3756,16 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetInfo(&info); ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); + ctx->webgpu_global_ctx->vendor = info.vendor; wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); + // for dot4I8packed + ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature( + wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct); bool valid_subgroup_matrix_config = false; #ifndef __EMSCRIPTEN__ diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 372ea79bf9d..758efa17d77 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -95,11 +95,10 @@ struct q5_1 { }; #endif - #ifdef Q8_1_T struct q8_1 { d: f16, - m: f16, + s: f16, // d * sum(qs[i]) qs: array }; #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl deleted file mode 100644 index fcbefdeb802..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ /dev/null @@ -1,747 +0,0 @@ -enable f16; - -#define DECLARE_BYTE_LOADERS_SRC0 -#include "common_decls.tmpl" - - -#ifdef FLOAT -const BLOCK_SIZE = 1u; - -#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) -const BLOCK_SIZE = 32u; - -#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) -const BLOCK_SIZE = 256u; -#endif - -#ifdef FLOAT -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); -} -#endif - -#ifdef Q4_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q4_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_1 = src0[src0_idx_base + offset]; - let d = f32(block_q4_1.d); - let m = f32(block_q4_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q4_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - let qh_packed = load_u32_at_src0(block_byte_base + 2); - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_1 = src0[src0_idx_base + offset]; - let d = f32(block_q5_1.d); - let m = f32(block_q5_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q5_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; - let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q8_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q8_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_1 = src0[src0_idx_base + offset]; - let d = f32(block_q8_1.d); - let m = f32(block_q8_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = block_q8_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q2_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(block.scales[is / 4], is % 4); - is++; - let dl = d * f32(sc & 0xF); - let ml = m * f32(sc >> 4); - for (var l: u32 = 0u; l < 16; l++) { - let q_idx = q_b_idx + k + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 3; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - } - return sum; -} -#endif - -#ifdef Q3_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - - // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 108); - - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes - // Bytes 96-107: 12 bytes of scales (3 u32s) - let kmask1: u32 = 0x03030303; - let kmask2: u32 = 0x0f0f0f0f; - var scale_vals: array; - scale_vals[0] = load_u32_at_src0(block_byte_base + 96); - scale_vals[1] = load_u32_at_src0(block_byte_base + 100); - scale_vals[2] = load_u32_at_src0(block_byte_base + 104); - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - // Bytes 0-31: 32 bytes of hmask (8 u32s) - var hmask_vals: array; - for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 32-95: 64 bytes of qs (16 u32s) - var qs_vals: array; - for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var m: u32 = 1; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(scale_vals[is / 4], is % 4); - is++; - let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3; - sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; - src1_i++; - } - } - m <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q4_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 0xF; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef Q5_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var u: u32 = 1; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qh_byte = get_byte(block.qh[l / 4], l % 4); - let qs_val = (q_byte >> shift) & 0xF; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - u <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q6_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes - - // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 208); - - // Bytes 0-127: 128 bytes of ql (32 u32s) - var ql_vals: array; - for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 128-191: 64 bytes of qh (16 u32s) - var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4); - } - - // Bytes 192-207: 16 bytes of scales (4 u32s) - var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var qh_b_idx: u32 = 0; - var sc_b_idx: u32 = 0; - for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { - for (var l: u32 = 0; l < 32; l++) { - let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); - let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); - let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); - - let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; - let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; - let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; - let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; - - let is = l/16; - let is1 = sc_b_idx + is; - let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); - let is2 = sc_b_idx + is + 2; - let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); - let is3 = sc_b_idx + is + 4; - let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); - let is4 = sc_b_idx + is + 6; - let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); - - sum += d * f32(sc1) * q1 * src1[src1_i + l]; - sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; - sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; - sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; - } - src1_i += 128; - qh_b_idx += 32; - sc_b_idx += 8; - } - return sum; -} -#endif - -#ifdef IQ2_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0_offset = block_byte_base + 2 + ib * 2; - let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at_src0(aux0_offset); - let aux1 = load_u32_at_src0(aux1_offset); - let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; - for (var l: u32 = 0; l < 4; l++) { - let ig = get_byte(aux0, l) * 8; - let is = (aux1 >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += db * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var scale_vals = array( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - for (var l: u32 = 0; l < 4; l++) { - let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF; - let ig = (qs_val & 511) * 8; - let is = qs_val >> 9; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qs_vals : array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - - var qh_vals: array; - qh_vals[0] = load_u32_at_src0(block_byte_base + 66); - qh_vals[1] = load_u32_at_src0(block_byte_base + 70); - - var scale_vals: array; - scale_vals[0] = load_u32_at_src0(block_byte_base + 74); - scale_vals[1] = load_u32_at_src0(block_byte_base + 78); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib ++) { - let s = get_byte(scale_vals[ib / 4], ib % 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - let qs_w = qs_vals[ib]; - for (var l: u32 = 0; l < 4; l++) { - let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; - let ig = (get_byte(qs_w, l) | qh_b) * 8; - let signs = get_byte(qs_vals[ib + 8], l); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ3_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at_src0(sc_sign_offset); - let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; - for (var l: u32 = 0; l < 4; l++) { - let is = (sc_sign >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0); - let ig2 = get_byte(ig_val, 1); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3xxs_grid[ig1], j); - let g2 = get_byte(iq3xxs_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += db * f32(g1) * m1 * src1[src1_i]; - sum += db * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - return sum; -} -#endif - -#ifdef IQ3_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qh_vals = array( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sign_vals: array; - for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4); - } - - var scale_vals = load_u32_at_src0(block_byte_base + 106); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 4; ib++) { - let s = get_byte(scale_vals, ib); - let db = array( - d * (1.0 + 2.0 * f32(s & 0xF)), - d * (1.0 + 2.0 * f32(s >> 4)) - ); - for (var k: u32 = 0; k < 2; k++) { - let dl = db[k]; - let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); - let sign_w = sign_vals[ib * 2 + k]; - for (var l: u32 = 0; l < 4; l++) { - let signs = get_byte(sign_w, l); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); - let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3s_grid[ig1], j); - let g2 = get_byte(iq3s_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += dl * f32(g1) * m1 * src1[src1_i]; - sum += dl * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - } - return sum; -} -#endif - -#ifdef IQ1_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF; - let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4); - for (var l: u32 = 0; l < 4; l++) { - let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl * (f32(gs) + delta) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - - -#ifdef IQ1_M -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - - let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); - let d = f32(bitcast>(scale).x); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; - let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; - let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; - var dl = array( - d * f32(2 * s1 + 1), - d * f32(2 * s2 + 1) - ); - - let qh = block.qh[ib / 2] >> (16 * (ib % 2)); - var idx = array( - get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), - get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), - get_byte(block.qs[ib], 2) | ((qh) & 0x700), - get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) - ); - var delta = array( - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) - ); - for (var l: u32 = 0; l < 4; l++) { - let ig = idx[l] * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ4_NL -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 32; - var sum = 0.0; - var qs: array; - for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - for (var j: u32 = 0; j < 16; j++) { - let qsb = get_byte(qs[j / 4], j % 4); - sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - return sum; -} -#endif - -#ifdef IQ4_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = unpack2x16float(block.d_scales_h)[0]; - let scales_h = block.d_scales_h >> 16; - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); - let dl = d * (f32(ls) - 32.0); - for (var j: u32 = 0; j < 16; j++) { - let iqs = ib * 16 + j; - let qsb = get_byte(block.qs[iqs / 4], iqs % 4); - sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - src1_i += 16; - } - return sum; -} -#endif - -struct MulMatParams { - offset_src0: u32, // in elements/blocks - offset_src1: u32, // in elements/blocks - offset_dst: u32, // in elements/blocks - m: u32, - n: u32, - k: u32, - // all strides are in elements/blocks - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array; // M rows, K columns -@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns - -@group(0) @binding(3) var params: MulMatParams; - -@compute @workgroup_size(256) -fn main(@builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let global_idx = wg_linear * 256u + local_id.x; - - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_idx >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_idx / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_idx % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.m; // output row - let col = dst2_rem % params.m; // output column - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { - sum += multiply_add(src0_idx_base, src1_idx_base, i); - } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index a194cf40468..f0a7fbd059a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -3,10 +3,18 @@ enable subgroups; #endif enable f16; +#ifdef MMVQ +requires packed_4x8_integer_dot_product; +#endif + #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" +#ifdef MMVQ +#include "mul_mat_vec_q_acc.tmpl" +#else #include "mul_mat_vec_acc.tmpl" +#endif struct MulMatParams { offset_src0: u32, @@ -28,9 +36,14 @@ struct MulMatParams { }; @group(0) @binding(0) var src0: array; + +#ifdef MMVQ +@group(0) @binding(1) var src1q: array; +#else @group(0) @binding(1) var src1: array; -@group(0) @binding(2) var dst: array; +#endif +@group(0) @binding(2) var dst: array; // "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 @group(0) @binding(3) var params: MulMatParams; @@ -75,10 +88,15 @@ fn main( let src12_idx = dst2_idx; let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; +#ifdef MMVQ + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); +#else + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); +#endif #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 711c7e829d8..08753b9d643 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -436,7 +436,6 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src } #endif - #ifdef MUL_ACC_Q3_K #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl new file mode 100644 index 00000000000..3ef2f77ebe0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -0,0 +1,303 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +#ifdef LEGACY_QUANTS +#define BLOCK_SIZE 32 +#define THREADS_PER_BLOCK 4 +#elif K_QUANTS +#define BLOCK_SIZE 256 +#define THREADS_PER_BLOCK 16 +#endif + +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +#define Q8_BLOCK_SIZE 32 + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE_BYTES 18 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id); + + return vec2( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE_BYTES 20 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id); + + return vec2( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base)), + f32(load_f16_at_src0(block_byte_base + 2u)) + ); +} +fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE_BYTES 34 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + return vec2( + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)), + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1)) + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id * 2u], + src1q[block].qs[inner_id * 2u + 1], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[block].d); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds); +} +#endif + +#ifdef LEGACY_QUANTS +fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { + var row_sum = 0; + let a_repacked = repack_a(a_byte_base, b_inner_id); + + row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); + row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); + + return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); +} + +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let b_inner_id = thread_id % THREADS_PER_BLOCK; + let b_block_idx = src1q_idx_base + block; + + let b_repacked = repack_b_qs(b_block_idx, b_inner_id); + let b_ds = repack_b_dm(b_block_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE_BYTES 84 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4 { + let ih2 = tid / 8u; + let phase = tid % 2u; + let iq4_idx = 2u * ih2 + phase; + let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx; + let qs_shift = tid & 6u; + return vec4( + (load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4 { + let phase = tid % 2u; + return vec4( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[q8_block_idx].d); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base + 80u)), + f32(load_f16_at_src0(block_byte_base + 82u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { + let scale_byte = block_byte_base + tid; + let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); + return vec2(f32(scale & 0xFu), f32(scale >> 4u)); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE_BYTES 144 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4 { + let iq4 = tid / 4u; + let phase = tid % 2u; + let nibble = (tid >> 1u) % 2u; + let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase; + let qs_shift = 4u * nibble; + return vec4( + (load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4 { + let phase = tid % 2u; + return vec4( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[q8_block_idx].d), + f32(src1q[q8_block_idx].s), + ); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base + 0u)), + f32(load_f16_at_src0(block_byte_base + 2u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { + let sc_m_idx = tid / 2u; + let scales_byte_base = block_byte_base + 4u; + let scales0_3 = load_u32_at_src0_aligned(scales_byte_base); + let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u); + let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u); + + let byte_idx = sc_m_idx & 3u; + let is_high = sc_m_idx >= 4u; + + let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu; + let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u); + let scale = f32(select(sc_low, sc_high, is_high)); + + let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu; + let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u); + let min_val = f32(select(mn_low, mn_high, is_high)); + + return vec2(scale, min_val); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +} +#endif + +#ifdef K_QUANTS +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + + for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { + let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl new file mode 100644 index 00000000000..b3f1fa04b80 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -0,0 +1,173 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +requires packed_4x8_integer_dot_product; + +#include "common_decls.tmpl" + +struct Params { + offset_src1: u32, + stride_12: u32, + stride_13: u32, + ne0: u32, + ne2: u32, + ne3: u32, +}; + +#define SRC1_TYPE vec4 + +@group(0) @binding(0) var src1: array; +@group(0) @binding(1) var src1q: array; + +@group(0) @binding(2) var params: Params; + +#ifdef USE_SUBGROUP_REDUCTION +fn cluster_max_8(v: f32) -> f32 { + var r = v; + r = max(r, subgroupShuffleXor(r, 1u)); + r = max(r, subgroupShuffleXor(r, 2u)); + r = max(r, subgroupShuffleXor(r, 4u)); + return r; +} + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) +fn cluster_add_i4x8(v: i32) -> i32 { + var r= v; + r += subgroupShuffleXor(r, 1u); + r += subgroupShuffleXor(r, 2u); + r += subgroupShuffleXor(r, 4u); + return r; +} +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION +#define CLUSTER_SIZE 8 + +var partial_amaxs: array, WG_SIZE / CLUSTER_SIZE>; +var partial_sums: array, WG_SIZE / CLUSTER_SIZE>; +#endif + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + let thread_id = local_id.x; + let num_vec4 = params.ne0 / 4u; + + let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne2 * params.ne3; + + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + if (wg_linear >= total_batches) { + return; + } + + let src13_idx = wg_linear / (params.ne2 * wg_per_vec); + let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; + let src11_wg_idx = wg_linear % wg_per_vec; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let src1_idx_vec4_base = src1_idx_base / 4u; + + let blocks_per_row = params.ne0 / 32u; + let blocks_per_wg = (WG_SIZE * 4u) / 32u; + let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; + let qs_idx = thread_id % 8u; + + // reduction + var q4 = vec4(0.0); + var q4_quants = 0u; + var thread_amax = 0.0; + + let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; + let is_valid = src11_vec4_idx < num_vec4; + +#ifdef USE_SUBGROUP_REDUCTION + + var d = 0.0; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3])); + } + + d = cluster_max_8(thread_amax) / 127.0; + + if (is_valid) { + let id = select(0.0, 1.0 / d, d > 0.0); + q4_quants = pack4xI8(vec4(round(q4 * id))); + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + src1q[src1q_idx].qs[qs_idx] = q4_quants; + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u); + let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum))); + + if (is_valid) { + if (qs_idx == 0u) { + src1q[src1q_idx].s = s; + } + } +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION + + var d = 0.0; + let cluster_id = thread_id / 8u; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3])); + partial_amaxs[cluster_id][qs_idx] = thread_amax; + } + + workgroupBarrier(); + + if (is_valid) { + let amax = max( + max( + max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])), + max( + max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7])) + ); + + d = amax / 127.0; + let id = select(0.0f, 1.0f / d, d > 0.0f); + + q4_quants = pack4xI8(vec4(round(q4 * id))); + src1q[src1q_idx].qs[qs_idx] = q4_quants; + + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + + partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u); + + workgroupBarrier(); + + if (is_valid) { + if (qs_idx == 0u) { + let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3] + + partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]); + src1q[src1q_idx].s = f16(s); + } + } + +#endif +#endif + +} From 049f0af3398f67acc547844dec0d14310ab2bbb5 Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Tue, 26 May 2026 13:59:00 +0900 Subject: [PATCH 163/289] SYCL: implement ggml_sycl_pool_vmm (llama/22862) * SYCL: implement ggml_sycl_pool_vmm * Add an option to bypass VMM with GGML_SYCL_DISABLE_VMM * Clean up debugging logging * document GGML_SYCL_DISABLE_VMM * Multi-stream MoE optimization * Revert "Multi-stream MoE optimization" This reverts commit 938929c3f13a562ec67c59e87cc5d38595444cce. * Update common.hpp Co-authored-by: Neo Zhang * Flip GGML_SYCL_DISABLE_VMM to GGML_SYCL_ENABLE_VMM * add logging for GGML_SYCL_ENABLE_VMM when extension is not available (SYCL_EXT_ONEAPI_VIRTUAL_MEM macro) * Apply suggestions from code review Co-authored-by: Alexey Kopytko * Apply suggestion from @sanmai * Apply suggestion from @sanmai --------- Co-authored-by: Neo Zhang --- ggml/src/ggml-sycl/common.hpp | 3 + ggml/src/ggml-sycl/ggml-sycl.cpp | 171 +++++++++++++++++++++++++++++-- 2 files changed, 163 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 6d19538215e..31e26ff48e4 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -224,6 +224,7 @@ struct sycl_device_info { int max_wg_per_cu; // max work groups per compute unit - refer to // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory size_t total_vram; sycl_hw_info hw_info; optimize_feature opt_feature; @@ -244,6 +245,8 @@ struct ggml_sycl_device_info { const ggml_sycl_device_info & ggml_sycl_info(); +static constexpr size_t SYCL_BUFFER_ALIGNMENT = 128; + struct ggml_sycl_pool { virtual ~ggml_sycl_pool() = default; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b3fbb621196..729a88b4db8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -37,6 +38,11 @@ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include #endif +#if SYCL_EXT_ONEAPI_VIRTUAL_MEM +# include +# include +# define GGML_SYCL_USE_VMM +#endif #include #include "ggml.h" @@ -70,6 +76,7 @@ int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; +int g_ggml_sycl_enable_vmm = 1; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; int g_ggml_sycl_use_async_mem_op_requested = 1; @@ -96,13 +103,30 @@ static ggml_sycl_device_info ggml_sycl_init() { // GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__); // #endif for (int i = 0; i < info.device_count; ++i) { - info.devices[i].vmm = 0; dpct::device_info prop; auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); +#if !defined(GGML_SYCL_USE_VMM) + info.devices[i].vmm = 0; +#else + info.devices[i].vmm = device.has(sycl::aspect::ext_oneapi_virtual_mem); + if (info.devices[i].vmm) { + // NB: SYCL's get_mem_granularity always returns the _minimum_ granularity, + // but the L0 API requires a larger page size for allocs above 2 MiB and + // rejects non-multiples with UR_RESULT_ERROR_INVALID_VALUE [sic]. + // Here we clamp it to 2 MiB for simplicity, but other devices may require + // calling zeVirtualMemQueryPageSize or yet unexposed public API. + const size_t physical_page = 2ull << 20; // 2 MiB + info.devices[i].vmm_granularity = std::max( + sycl::ext::oneapi::experimental::get_mem_granularity( + device, sycl::context(device)), + physical_page); + } +#endif + info.default_tensor_split[i] = total_vram; total_vram += prop.get_global_mem_size(); @@ -234,6 +258,7 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); + g_ggml_sycl_enable_vmm = get_sycl_env("GGML_SYCL_ENABLE_VMM", 1); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); #ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); @@ -275,6 +300,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); #endif +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: no\n"); +#endif GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); @@ -293,6 +323,11 @@ static void ggml_check_sycl() try { GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); +#endif +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: %d\n", g_ggml_sycl_enable_vmm); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: virtual memory extension is not available\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); @@ -754,7 +789,7 @@ catch (sycl::exception const &exc) { } static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -1177,7 +1212,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(gg } static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -1462,6 +1497,121 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +// pool with virtual memory management +#if defined(GGML_SYCL_USE_VMM) +struct ggml_sycl_pool_vmm : public ggml_sycl_pool { + static const size_t SYCL_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB + + int device; + sycl::context ctx; + sycl::device dev; + + uintptr_t pool_addr = 0; + size_t pool_used = 0; + size_t pool_size = 0; + size_t granularity; + + // physical_mem owns the commits (unlike cuMemMap) + struct mapping { + sycl::ext::oneapi::experimental::physical_mem phys; + void * map_ptr; + }; + std::vector mappings; + + explicit ggml_sycl_pool_vmm(queue_ptr qptr_, int device_) : + device(device_), + ctx(qptr_->get_context()), + dev(qptr_->get_device()), + granularity(ggml_sycl_info().devices[device_].vmm_granularity) { + } + + ~ggml_sycl_pool_vmm() { + if (pool_addr == 0) { + return; + } + + // Per spec, unmap must (a) match the exact (ptr, size) of an earlier + // physical_mem::map() call and (b) precede destruction of the + // physical_mem objects (their dtors won't unmap). + for (auto & m : mappings) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::unmap( + m.map_ptr, m.phys.size(), ctx))); + } + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::free_virtual_mem( + pool_addr, SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + void * alloc(size_t size, size_t * actual_size) override { + // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types + size = GGML_PAD(size, SYCL_BUFFER_ALIGNMENT); + + size_t avail = pool_size - pool_used; + + if (size > avail) { + // round up to the next multiple of the granularity + size_t reserve_size = GGML_PAD(size - avail, granularity); + + GGML_ASSERT(pool_size + reserve_size <= SYCL_POOL_VMM_MAX_SIZE); + + // allocate more physical memory + std::optional phys; + SYCL_CHECK(CHECK_TRY_ERROR(phys.emplace(dev, ctx, reserve_size))); + + // reserve virtual address space (if not already reserved) + if (pool_addr == 0) { + SYCL_CHECK(CHECK_TRY_ERROR( + pool_addr = sycl::ext::oneapi::experimental::reserve_virtual_mem( + SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + // map at the end of the pool + void * map_ptr = nullptr; + SYCL_CHECK(CHECK_TRY_ERROR( + map_ptr = phys->map(pool_addr + pool_size, reserve_size, + sycl::ext::oneapi::experimental::address_access_mode::read_write))); + + // stash these so we could unmap this exact range in dtor + mappings.push_back({ + std::move(*phys), + map_ptr, + }); + + // add to the pool + pool_size += reserve_size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: size increased to %llu MB (reserved %llu MB)\n", + device, (unsigned long long) (pool_size/1024/1024), + (unsigned long long) (reserve_size/1024/1024)); +#endif + } + + GGML_ASSERT(pool_addr != 0); + + void * ptr = reinterpret_cast(pool_addr + pool_used); + *actual_size = size; + pool_used += size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + return ptr; + } + + void free(void * ptr, size_t size) override { +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + pool_used -= size; + + // all deallocations must be in reverse order of the allocations + GGML_ASSERT(ptr == reinterpret_cast(pool_addr + pool_used)); + } +}; +#endif // defined(GGML_SYCL_USE_VMM) + struct ggml_sycl_pool_host : public ggml_sycl_pool { queue_ptr qptr; int device; @@ -1542,20 +1692,19 @@ std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(que } std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { - // TBD: NO VMM support - // if (ggml_sycl_info().devices[device].vmm) { - // return std::unique_ptr(new ggml_sycl_pool_vmm(device)); - // } - return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); +#if defined(GGML_SYCL_USE_VMM) + if (g_ggml_sycl_enable_vmm && ggml_sycl_info().devices[device].vmm) { + return std::unique_ptr(new ggml_sycl_pool_vmm(qptr, device)); + } +#endif // defined(GGML_SYCL_USE_VMM) + return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); } + std::unique_ptr ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) { return std::unique_ptr(new ggml_sycl_fattn_kv_buffers(qptr, device)); } -// TBD pool with virtual memory management -// struct ggml_sycl_pool_vmm : public ggml_sycl_pool - /// kernels typedef void (*ggml_sycl_op_mul_mat_t)( ggml_backend_sycl_context & ctx, From f8df28d3319ecf97d8bbc27cdc22bcff8f1dcbe0 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Tue, 26 May 2026 06:20:05 -0700 Subject: [PATCH 164/289] hexagon: add support for CONCAT op (llama/23648) * hexagon: add support for CONCAT with optimized concat_2d_transposed qwen3.5 models are quite heavy on the CONCAT with large and transposed src1. * hex-concat: use fastdiv in generic version * hex-concat: make checks for transposed a bit more readable * hex-concat: reoder dma ops for better pipelining * hex-cont/cpy: optimize CPY and CONT ops The primary change is to avoid scalar divs in the inner loops. We were calling hvx_copy_uu(... type_size) where type_size is non a constexpr. This causes runtime divs by that value which is normally just 4 or 2 (f32/f16). * hex-get-rows: optimize GET_ROWS for large rows We now use DMA for larger rows and also split them into chunks to improve perf for Qwen3.5 and other models that do lots of GET_ROWS with huge (2MB+ rows). Also bump the DMA queue depth now that we can take advantage of it. * hex-concat: unroll the inner loops of concat_2d * hex-concat: more updates to concat_2d to improve perf a bit further * hex-cpy: fixed n_rows per thread checks in the copy ops * hmx-fa: fix alignment issues while computing dma sizes * hex-set-rows: add early returns for idle threads * hvx-rope: minor optimization to replace loops with fastdiv logic * hex-rope: replace scalar tail processing with HVX * hex-rope: optimize rope cache init with HVX Add hvx-utils sin/cos helpers that use an aprox method (similar to rsqrt, inverse, etc) Use the helpers to optimize ROPE. --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 24 ++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/concat-ops.c | 275 ++++++++++++++++ ggml/src/ggml-hexagon/htp/cpy-ops.c | 310 ++++++++++-------- ggml/src/ggml-hexagon/htp/get-rows-ops.c | 120 ++++++- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 21 +- ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/hvx-sin-cos.h | 90 +++++ ggml/src/ggml-hexagon/htp/hvx-utils.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 6 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 240 ++++++++++---- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 6 + ggml/src/ggml-hexagon/htp/unary-ops.c | 2 +- 14 files changed, 868 insertions(+), 231 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/concat-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-sin-cos.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9db99cb0f3a..1c8ecc197e9 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2874,6 +2874,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_CONCAT: return HTP_OP_CONCAT; case GGML_OP_SCALE: return HTP_OP_SCALE; case GGML_OP_SQR: return HTP_OP_SQR; case GGML_OP_SQRT: return HTP_OP_SQRT; @@ -3286,6 +3287,25 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_concat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + int dim = ((const int32_t *) op->op_params)[0]; + if (dim < 0 || dim >= GGML_MAX_DIMS) { + return false; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + const struct ggml_tensor * src = op->src[i]; + if (!src) { + continue; + } + if (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_I32 && src->type != GGML_TYPE_F16) { + return false; + } + } + + return true; +} + static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * dst = op; @@ -3434,6 +3454,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_CONCAT: + supp = ggml_hexagon_supported_concat(sess, op); + break; + case GGML_OP_FILL: supp = ggml_hexagon_supported_fill(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 36f923243cd..33d67dda9cc 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -35,6 +35,7 @@ add_library(${HTP_LIB} SHARED ssm-conv.c cumsum-ops.c fill-ops.c + concat-ops.c diag-ops.c solve-tri-ops.c gated-delta-net-ops.c diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c new file mode 100644 index 00000000000..61580f2c08f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -0,0 +1,275 @@ +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hexagon_types.h" +#include "hexagon_protos.h" +#include "hvx_hexagon_protos.h" +#include "hex-dma.h" +#include "vtcm-utils.h" +#include "hvx-utils.h" +#include "hex-fastdiv.h" +#include + +struct htp_concat_context { + struct htp_ops_context * octx; + uint32_t dim; + uint32_t nrows_per_thread; + struct fastdiv_values div_ne0; + struct fastdiv_values div_ne1; + struct fastdiv_values div_ne2; +}; + +static void concat_2d_f32_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 32; + const uint32_t spad1_stride = block_i * sizeof(float); + + int32_t offsets[32] __attribute__((aligned(128))); + for(int k=0; k<32; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 32); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(float), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(float); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(float); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 32) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(float)); + Q6_vgather_ARMVw(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(float); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(float), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_2d_f16_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 64; + const uint32_t spad1_stride = block_i * sizeof(__fp16); + + int16_t offsets[64] __attribute__((aligned(128))); + for(int k=0; k<64; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 64); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(__fp16), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(__fp16); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(__fp16); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 64) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(__fp16)); + Q6_vgather_ARMVh(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(__fp16); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(__fp16), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_generic(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const int dim = cctx->dim; + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + + const uint32_t ne[4] = {dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]}; + const uint32_t total_elements = ne[0] * ne[1] * ne[2] * ne[3]; + const uint32_t chunk_size = (total_elements + nth - 1) / nth; + + const uint32_t start_idx = MIN(ith * chunk_size, total_elements); + const uint32_t end_idx = MIN(start_idx + chunk_size, total_elements); + + // Naive scalar element-wise copy + for (uint32_t idx = start_idx; idx < end_idx; idx++) { + uint32_t idx_div_ne0 = fastdiv(idx, &cctx->div_ne0); + uint32_t i0 = idx - idx_div_ne0 * ne[0]; + + uint32_t idx_div_ne01 = fastdiv(idx_div_ne0, &cctx->div_ne1); + uint32_t i1 = idx_div_ne0 - idx_div_ne01 * ne[1]; + + uint32_t idx_div_ne012 = fastdiv(idx_div_ne01, &cctx->div_ne2); + uint32_t i2 = idx_div_ne01 - idx_div_ne012 * ne[2]; + uint32_t i3 = idx_div_ne012; + + uint8_t * dst_ptr = (uint8_t *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2] + i1 * dst->nb[1] + i0 * dst->nb[0]; + + uint32_t idx_dim = 0; + if (dim == 0) idx_dim = i0; + else if (dim == 1) idx_dim = i1; + else if (dim == 2) idx_dim = i2; + else if (dim == 3) idx_dim = i3; + + const struct htp_tensor * src = (idx_dim < src0->ne[dim]) ? src0 : src1; + + uint32_t s0 = i0; + uint32_t s1 = i1; + uint32_t s2 = i2; + uint32_t s3 = i3; + + if (dim == 0 && src == src1) s0 -= src0->ne[0]; + if (dim == 1 && src == src1) s1 -= src0->ne[1]; + if (dim == 2 && src == src1) s2 -= src0->ne[2]; + if (dim == 3 && src == src1) s3 -= src0->ne[3]; + + uint8_t * src_ptr = (uint8_t *)src->data + s3 * src->nb[3] + s2 * src->nb[2] + s1 * src->nb[1] + s0 * src->nb[0]; + + if (type_size == 4) { + *(float*)dst_ptr = *(float*)src_ptr; + } else { + *(__fp16*)dst_ptr = *(__fp16*)src_ptr; + } + } +} + +int op_concat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + int dim = octx->op_params[0]; + + bool is_2d = dst->ne[2] == 1 && dst->ne[3] == 1; + + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + bool is_src1_transposed = (src1->nb[0] > src1->nb[1]); + bool is_src0_transposed = (src0->nb[0] > src0->nb[1]); + + uint32_t n_threads = octx->n_threads; + struct htp_concat_context cctx; + cctx.octx = octx; + cctx.dim = dim; + cctx.div_ne0 = init_fastdiv_values(dst->ne[0]); + cctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + cctx.div_ne2 = init_fastdiv_values(dst->ne[2]); + + void (*worker_func)(unsigned int, unsigned int, void *) = concat_generic; + + if (dim == 0 && is_2d && is_src1_transposed && !is_src0_transposed) { + n_threads = MIN(dst->ne[1], n_threads); + if (n_threads < 1) { + n_threads = 1; + } + uint32_t block_i = (type_size == 4) ? 32 : 64; + + cctx.nrows_per_thread = hmx_ceil_div(dst->ne[1], n_threads); + + // Allocate VTCM + uint32_t spad1_stride = block_i * type_size; + + uint32_t src1_ne0_padded = hex_round_up(src1->ne[0], block_i); + uint32_t spad0_row_bytes = hex_round_up((src0->ne[0] + src1_ne0_padded) * type_size, VLEN); + + octx->src0_spad.size_per_thread = block_i * spad0_row_bytes; + octx->src1_spad.size_per_thread = src1_ne0_padded * spad1_stride + block_i * VLEN; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + + if (octx->src0_spad.size + octx->src1_spad.size > octx->ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + if (type_size == 4) { + worker_func = concat_2d_f32_transposed; + } else { + worker_func = concat_2d_f16_transposed; + } + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &cctx, n_threads); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index 5c040a32224..ae507effa51 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -28,158 +28,170 @@ struct htp_copy_context { uint32_t dst_blocks_per_row; uint32_t src0_nrows_per_thread; - - void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); }; #define cpy_preamble \ const struct htp_tensor *src0 = octx->src[0]; \ const struct htp_tensor *dst = octx->dst; \ \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ const uint32_t nr = ne01; -static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { - cpy_preamble; - - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // copy by rows - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - #pragma unroll(2) - for (uint32_t i01 = ir0; i01 < ir1; i01++) { - uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; - uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2); - hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size); - } - } - } +#define DEFINE_CPY_SAMESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_sameshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + for (uint32_t i03 = 0; i03 < ne03; i03++) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + _Pragma("unroll(4)") \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; \ + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; \ + hex_l2fetch(src0_ptr, ne00 * ELEM_SIZE, nb01, 2); \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + } \ } -static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) { - cpy_preamble; - - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // Fast path: when both src0 and dst are contiguous in memory - // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. - const bool src0_contig = (nb00 == ct->src0_type_size) && - (nb01 == ne00 * nb00) && - (nb02 == ne01 * nb01) && - (nb03 == ne02 * nb02); - const bool dst_contig = (nb0 == ct->dst_type_size) && - (nb1 == ne0 * nb0) && - (nb2 == ne1 * nb1) && - (nb3 == ne2 * nb2); - - if (src0_contig && dst_contig) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; - uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; - uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; - hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); - } - } - return; - } - - // dst counters - int64_t k10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - // number of blocks in a row - const int64_t nk00 = ct->src0_blocks_per_row; - const int64_t nk0 = ct->dst_blocks_per_row; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - k10 += nk00 * ir0; - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t k00 = 0; k00 < nk00; k00++) { - const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - memcpy(dst_ptr, src0_ptr, ct->dst_type_size); - - if (++k10 == nk0) { - k10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - k10 += nk00 * (ne01 - ir1); - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } +DEFINE_CPY_SAMESHAPE(f32, float, 4) +DEFINE_CPY_SAMESHAPE(f16, __fp16, 2) + +#define DEFINE_CPY_RESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_reshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + const bool src0_contig = (nb00 == ELEM_SIZE) && \ + (nb01 == ne00 * nb00) && \ + (nb02 == ne01 * nb01) && \ + (nb03 == ne02 * nb02); \ + const bool dst_contig = (nb0 == ELEM_SIZE) && \ + (nb1 == ne0 * nb0) && \ + (nb2 == ne1 * nb1) && \ + (nb3 == ne2 * nb2); \ + if (src0_contig && dst_contig) { \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; \ + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ELEM_SIZE; \ + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + const bool reshape_flat_fast = (ne03 == 1 && ne2 == 1 && ne3 == 1) && \ + (ne0 == ne00 * ne01) && (ne1 == ne02) && \ + (nb00 == ELEM_SIZE) && (nb0 == ELEM_SIZE); \ + if (reshape_flat_fast) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t * src0_ptr = (uint8_t *) src0->data + i01 * nb01 + i02 * nb02; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + i01 * ne00 * ELEM_SIZE + i02 * nb1; \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + int64_t k10 = 0; \ + int64_t i11 = 0; \ + int64_t i12 = 0; \ + int64_t i13 = 0; \ + const int64_t nk00 = ct->src0_blocks_per_row; \ + const int64_t nk0 = ct->dst_blocks_per_row; \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + k10 += nk00 * ir0; \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + for (int64_t i01 = ir0; i01 < ir1; i01++) { \ + for (int64_t k00 = 0; k00 < nk00; k00++) { \ + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); \ + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); \ + memcpy(dst_ptr, src0_ptr, ELEM_SIZE); \ + if (++k10 == nk0) { \ + k10 = 0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ + k10 += nk00 * (ne01 - ir1); \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ } -static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +DEFINE_CPY_RESHAPE(f32, float, 4) +DEFINE_CPY_RESHAPE(f16, __fp16, 2) + +static void cpy_thread_f16_f32_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -195,13 +207,16 @@ static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +static void cpy_thread_f32_f16_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -217,11 +232,6 @@ static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_work_func(unsigned int n, unsigned int i, void *data) { - struct htp_copy_context *ct = (struct htp_copy_context *) data; - ct->copy(ct, ct->octx, n, i); -} - int op_cpy(struct htp_ops_context * octx) { cpy_preamble; @@ -254,22 +264,32 @@ int op_cpy(struct htp_ops_context * octx) { ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; + worker_callback_t copy_fun; + if (sametype && sameshape) { - ct.copy = cpy_thread_sametype_sameshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_sameshape; + } else { + copy_fun = cpy_thread_f16_sameshape; + } } else if (sameshape) { /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32) - ct.copy = cpy_thread_f16_f32_sameshape; + copy_fun = cpy_thread_f16_f32_sameshape; else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16) - ct.copy = cpy_thread_f32_f16_sameshape; + copy_fun = cpy_thread_f32_f16_sameshape; else return HTP_STATUS_NO_SUPPORT; } else if (sametype) { - ct.copy = cpy_thread_sametype_reshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_reshape; + } else { + copy_fun = cpy_thread_f16_reshape; + } } else { return HTP_STATUS_NO_SUPPORT; } - worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads); + worker_pool_run_func(octx->ctx->worker_pool, copy_fun, &ct, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 5a1dc933860..bf7063e9880 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -17,9 +17,13 @@ struct get_rows_context { struct htp_ops_context * octx; - uint32_t src1_nrows_per_thread; + uint32_t tasks_per_thread; + uint32_t total_tasks; + uint32_t chunks_per_row; + uint32_t chunk_size; struct fastdiv_values get_rows_div_ne10; struct fastdiv_values get_rows_div_ne10_ne11; + struct fastdiv_values get_rows_div_chunks_per_row; }; #define get_rows_preamble \ @@ -52,20 +56,23 @@ struct get_rows_context { \ const uint32_t nr = ne10 * ne11 * ne12; -static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { +static void get_rows_thread_f32_f32_dma(unsigned int nth, unsigned int ith, void *data) { struct get_rows_context * grctx = (struct get_rows_context *)data; struct htp_ops_context * octx = grctx->octx; get_rows_preamble; uint64_t qt = HAP_perf_get_qtimer_count(); - // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = grctx->src1_nrows_per_thread; + const uint32_t dr = grctx->tasks_per_thread; const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + dma_queue * dma_queue = octx->ctx->dma[ith]; for (uint32_t i = ir0; i < ir1; ++i) { const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; @@ -73,28 +80,76 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i10 = rem - i11 * ne10; const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; - uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i01 >= ne01) { - // invalid index, skip for now to avoid crash continue; } const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; - hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + + while (!dma_queue_push(dma_queue, dma_make_ptr((void *)dst_ptr, (const void *)src0_ptr), nb1, nb01, ne00 * sizeof(float), 1)) { + dma_queue_pop(dma_queue); + } } + dma_queue_flush(dma_queue); qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); - FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + FARF(HIGH, "get-rows-f32-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } -int op_get_rows(struct htp_ops_context * octx) { +static void get_rows_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; get_rows_preamble; - const uint32_t n_threads = MIN(nr, octx->n_threads); + uint64_t qt = HAP_perf_get_qtimer_count(); + + const uint32_t dr = grctx->tasks_per_thread; + const uint32_t ir0 = dr * ith; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); + + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + + const uint32_t chunks_per_row = grctx->chunks_per_row; + const uint32_t chunk_size = grctx->chunk_size; + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t row_idx = fastdiv(i, &grctx->get_rows_div_chunks_per_row); + const uint32_t chunk_idx = i - row_idx * chunks_per_row; + + const uint32_t i12 = fastdiv(row_idx, &grctx->get_rows_div_ne10_ne11); + const uint32_t rem = row_idx - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; + uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + continue; + } + + const uint32_t offset = chunk_idx * chunk_size; + if (offset < ne00) { + const uint32_t copy_size = MIN(chunk_size, ne00 - offset); + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03 + offset * sizeof(float); + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3 + offset * sizeof(float); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, copy_size); + } + } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); +} + +int op_get_rows(struct htp_ops_context * octx) { + get_rows_preamble; if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; @@ -112,13 +167,52 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } + const uint32_t nb00 = octx->src[0]->nb[0]; + const uint32_t nb0 = octx->dst->nb[0]; + + const bool can_use_dma = (nb00 == sizeof(float)) && (nb0 == sizeof(float)); + const bool use_dma = can_use_dma && (ne00 >= 2048); + struct get_rows_context grctx; grctx.octx = octx; grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); - grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; + if (use_dma) { + grctx.chunks_per_row = 1; + grctx.chunk_size = ne00; + grctx.total_tasks = nr; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(1); + + const uint32_t n_threads = MIN(nr, octx->n_threads); + grctx.tasks_per_thread = (nr + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_dma, &grctx, n_threads); + } else { + uint32_t chunks_per_row = 1; + uint32_t chunk_size = ne00; + uint32_t total_tasks = nr; + + if (nr < octx->n_threads) { + const uint32_t min_chunk_size = 1024; + uint32_t max_chunks = ne00 / min_chunk_size; + if (max_chunks == 0) { + max_chunks = 1; + } + chunks_per_row = MIN((octx->n_threads + nr - 1) / nr, max_chunks); + chunk_size = (ne00 + chunks_per_row - 1) / chunks_per_row; + total_tasks = nr * chunks_per_row; + } + + grctx.chunks_per_row = chunks_per_row; + grctx.chunk_size = chunk_size; + grctx.total_tasks = total_tasks; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(chunks_per_row); - worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads); + const uint32_t n_threads = MIN(total_tasks, octx->n_threads); + grctx.tasks_per_thread = (total_tasks + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_hvx, &grctx, n_threads); + } return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 9e1b778b01f..a496f6289ae 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -50,8 +50,8 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong - const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf - const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] @@ -1278,7 +1278,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { struct hmx_fa_context factx; memset(&factx, 0, sizeof(factx)); factx.octx = octx; - factx.n_threads = octx->ctx->n_threads; + factx.n_threads = n_threads; factx.DK = DK; factx.DV = DV; factx.n_kv = nek1; @@ -1328,10 +1328,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); // ======== VTCM allocation (GQA-aware) ======== + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(size_k_row, 128); + const size_t size_v_row_padded = hex_round_up(size_v_row, 128); + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); - const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); - const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * size_k_row_padded, 4096); + const size_t v_dma_bytes = hex_align_up(Bc * size_v_row_padded, 4096); const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); @@ -1401,11 +1406,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ======== DMA setup ======== dma_queue * const dma = ctx->dma[0]; - // Padded row sizes for DMA - const size_t size_k_row = nek0 * sizeof(__fp16); - const size_t size_v_row = nev0 * sizeof(__fp16); - const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); - const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + // Padded row sizes for DMA (defined in outer scope) const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 6fe3e6c7d85..51f9243ce0a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -104,6 +104,7 @@ int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); +int op_concat(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 9d905a30133..54cfadd9b0a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -89,6 +89,7 @@ enum htp_op_code { HTP_OP_TRI, HTP_OP_PAD, HTP_OP_NORM, + HTP_OP_CONCAT, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h new file mode 100644 index 00000000000..c5b9a5d47c1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h @@ -0,0 +1,90 @@ +#ifndef HVX_SIN_COS_H +#define HVX_SIN_COS_H + +#include "hvx-base.h" +#include "hvx-floor.h" + +static inline HVX_Vector hvx_vec_cos_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for cos(y) + HVX_Vector c4 = hvx_vec_splat_f32(2.3557242013849433e-05f); + HVX_Vector c3 = hvx_vec_splat_f32(-0.0013871428263450528f); + HVX_Vector c2 = hvx_vec_splat_f32(0.041665895266688284f); + HVX_Vector c1 = hvx_vec_splat_f32(-0.4999999360426369f); + HVX_Vector c0 = hvx_vec_splat_f32(0.9999999999071725f); + + HVX_Vector cos_y = hvx_vec_add_f32_f32(c3, hvx_vec_mul_f32_f32(z, c4)); + cos_y = hvx_vec_add_f32_f32(c2, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c1, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c0, hvx_vec_mul_f32_f32(z, cos_y)); + + return hvx_vec_mul_f32_f32(cos_y, sign); +} + +static inline HVX_Vector hvx_vec_sin_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for sin(y) + HVX_Vector s4 = hvx_vec_splat_f32(2.642186986152672e-06f); + HVX_Vector s3 = hvx_vec_splat_f32(-0.00019825318964070864f); + HVX_Vector s2 = hvx_vec_splat_f32(0.00833326283319605f); + HVX_Vector s1 = hvx_vec_splat_f32(-0.16666666082087775f); + HVX_Vector s0 = hvx_vec_splat_f32(0.999999999915155f); + + HVX_Vector sin_y = hvx_vec_add_f32_f32(s3, hvx_vec_mul_f32_f32(z, s4)); + sin_y = hvx_vec_add_f32_f32(s2, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s1, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s0, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_mul_f32_f32(y, sin_y); + + return hvx_vec_mul_f32_f32(sin_y, sign); +} + +#endif /* HVX_SIN_COS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index e0452811ec3..0a760cd344c 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -14,6 +14,8 @@ #include "hvx-sqrt.h" #include "hvx-arith.h" #include "hvx-div.h" +#include "hvx-floor.h" +#include "hvx-sin-cos.h" #include "hvx-base.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e8619388478..f3a0866c7cd 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -420,8 +420,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { - // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(128); + ctx->dma[i] = dma_queue_create(256); // queue depth } // init worker pool @@ -601,6 +600,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_PAD: return op_pad(octx); + case HTP_OP_CONCAT: + return op_concat(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index b398e19f06e..c839044b84f 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -7,6 +7,7 @@ #include #include +#include #include "hex-dma.h" #include "hvx-utils.h" @@ -75,6 +76,9 @@ struct htp_rope_context { size_t theta_cache_offset; uint32_t src0_nrows; + struct fastdiv_values div_ne2_ne1; + struct fastdiv_values div_ne1; + uint64_t t_start; }; @@ -117,13 +121,84 @@ static __attribute__((noinline)) void rope_cache_init(const float theta_base, float * cache, const float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - float theta = theta_base; +#if __HVX_ARCH__ >= 79 + const bool is_v79_or_newer = true; +#else + const bool is_v79_or_newer = false; +#endif + + if (is_v79_or_newer && ext_factor == 0.0f) { + // Fast path: fully vectorized + // We process 32 pairs (64 elements) per iteration. + const uint32_t n_blocks = ne0 / 64; + + // Initialize theta scale powers: [1.0f, theta_scale, theta_scale^2, ..., theta_scale^31] + float __attribute__((aligned(128))) theta_powers[32]; + theta_powers[0] = 1.0f; + for (int j = 1; j < 32; j++) { + theta_powers[j] = theta_powers[j - 1] * theta_scale; + } + HVX_Vector v_theta_powers = hvx_vmem(theta_powers); - for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { - const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; - rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + HVX_Vector v_freq_scale = hvx_vec_splat_f32(freq_scale); + HVX_Vector v_mscale = hvx_vec_splat_f32(mscale); + + // Base theta starts at theta_base + float theta_block = theta_base; + // The scale factor for the next block is theta_scale^32 + float theta_scale_32 = 1.0f; + for (int j = 0; j < 32; j++) { + theta_scale_32 *= theta_scale; + } + + for (uint32_t b = 0; b < n_blocks; b++) { + uint32_t i0 = b * 64; + HVX_Vector v_theta_base = hvx_vec_splat_f32(theta_block); + HVX_Vector v_theta = hvx_vec_mul_f32_f32(v_theta_base, v_theta_powers); + + if (freq_factors) { + // Load 32 elements of freq_factors + HVX_Vector v_ff = hvx_vmemu(freq_factors + i0 / 2); + HVX_Vector v_inv_ff = hvx_vec_inverse_f32(v_ff); + v_theta = hvx_vec_mul_f32_f32(v_theta, v_inv_ff); + } + + HVX_Vector v_theta_final = hvx_vec_mul_f32_f32(v_theta, v_freq_scale); + + HVX_Vector vcos = hvx_vec_cos_f32(v_theta_final); + HVX_Vector vsin = hvx_vec_sin_f32(v_theta_final); + + vcos = hvx_vec_mul_f32_f32(vcos, v_mscale); + vsin = hvx_vec_mul_f32_f32(vsin, v_mscale); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(vsin, vcos, -4); - theta *= theta_scale; + if (((uintptr_t)cache) % 128 == 0) { + hvx_vmem(cache + i0 + 0) = Q6_V_lo_W(vstore); + hvx_vmem(cache + i0 + 32) = Q6_V_hi_W(vstore); + } else { + hvx_vec_store_u(cache + i0 + 0, 32 * sizeof(float), Q6_V_lo_W(vstore)); + hvx_vec_store_u(cache + i0 + 32, 32 * sizeof(float), Q6_V_hi_W(vstore)); + } + + theta_block *= theta_scale_32; + } + + // Leftovers + float theta = theta_block; + for (uint32_t i0 = n_blocks * 64; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } + } else { + // Fallback to original scalar loop + float theta = theta_base; + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } } } @@ -195,24 +270,18 @@ static void rope_corr_dims(int n_dims, } static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2 + const uint32_t he = ne / 2; + const uint32_t nvec = he / 32; + const uint32_t nloe = he % 32; - uint32_t he = ne / 2; // half_dims offset in elements - uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i]; + HVX_Vector v1 = hvx_vmemu(src0 + he + i * 32); - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i += 2) { - HVX_Vector v0 = vsrc[i/2+0]; - HVX_Vector v1 = vsrc[i/2+hv]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; - - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); @@ -222,37 +291,45 @@ static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * rest HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4); - vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5); + ((HVX_Vector *) dst)[i] = Q6_Vsf_equals_Vqf32(v4); + hvx_vmemu(dst + he + i * 32) = Q6_Vsf_equals_Vqf32(v5); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i/2]; - float x1 = src0[i/2 + he]; - dst[i/2] = x0 * cos_theta - x1 * sin_theta; - dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 32); + HVX_Vector v1 = hvx_vmemu(src0 + he + nvec * 32); + + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 1]; + + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + hvx_vec_store_u(dst + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v4)); + hvx_vec_store_u(dst + he + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v5)); } } static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two + const uint32_t nvec = ne / 64; + const uint32_t nloe = ne % 64; - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i+=2) { - HVX_Vector v0 = vsrc[i+0]; - HVX_Vector v1 = vsrc[i+1]; + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i * 2 + 0]; + HVX_Vector v1 = ((const HVX_Vector *) src0)[i * 2 + 1]; - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); @@ -264,17 +341,52 @@ static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - vdst[i+0] = Q6_V_lo_W(vstore); - vdst[i+1] = Q6_V_hi_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 0] = Q6_V_lo_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 1] = Q6_V_hi_W(vstore); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i+0]; - float x1 = src0[i+1]; - dst[i+0] = x0 * cos_theta - x1 * sin_theta; - dst[i+1] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + if (nloe <= 32) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(Q6_V_vzero(), v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(Q6_V_vzero(), v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + hvx_vec_store_u(dst + nvec * 64, nloe * sizeof(float), Q6_V_lo_W(vstore)); + } else { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v1 = hvx_vmemu(src0 + nvec * 64 + 32); + + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + HVX_Vector v3 = hvx_vmemu(theta_cache + nvec * 64 + 32); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + ((HVX_Vector *) dst)[nvec * 2 + 0] = Q6_V_lo_W(vstore); + hvx_vec_store_u(dst + nvec * 64 + 32, (nloe - 32) * sizeof(float), Q6_V_hi_W(vstore)); + } } } @@ -348,13 +460,19 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { const int32_t * pos = (const int32_t *) src1->data; const float * freq_factors = src2 ? (const float *) src2->data : NULL; - uint32_t ir = 0; + const uint32_t i3_start = fastdiv(src0_start_row, &rctx->div_ne2_ne1); + const uint32_t rem = fastmodulo(src0_start_row, ne2 * ne1, &rctx->div_ne2_ne1); + const uint32_t i2_start = fastdiv(rem, &rctx->div_ne1); + const uint32_t i1_start = fastmodulo(rem, ne1, &rctx->div_ne1); + + uint32_t ir = src0_start_row; uint32_t prev_i2 = (uint32_t) -1; - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads - if (ir < src0_start_row) { ir++; i1++; continue; } + for (uint32_t i3 = i3_start; i3 < ne3; i3++) { // batch + const uint32_t i2_init = (i3 == i3_start) ? i2_start : 0; + for (uint32_t i2 = i2_init; i2 < ne2; i2++) { // seq-len + const uint32_t i1_init = (i3 == i3_start && i2 == i2_start) ? i1_start : 0; + for (uint32_t i1 = i1_init; i1 < ne1; ) { // attn-heads if (ir >= src0_end_row) goto done; // Rows in this block @@ -407,9 +525,6 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); } - - // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, - // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } // Skip output DMA transactions from prev block (if any) @@ -489,7 +604,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { // Aligned row sizes for VTCM const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 256); // Calculate spad sizes per thread size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; @@ -546,6 +661,11 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { rctx.src0_nrows = src0_nrows; rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + if (src0_nrows > 0) { + rctx.div_ne2_ne1 = init_fastdiv_values(dst->ne[2] * dst->ne[1]); + rctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + } + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 0def7b408bf..58c54967db0 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -65,6 +65,9 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); @@ -109,6 +112,9 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 40d2d60153a..7d0431d8ba8 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -207,7 +207,7 @@ static void hvx_fast_norm_f32(const uint8_t * restrict src, // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); - HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v)); #pragma unroll(4) for (int i = 0; i < nvec; i++) { From a0efd13f0fe9e2123a5d04f57bb353225c5f4453 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 26 May 2026 08:48:05 -0500 Subject: [PATCH 165/289] vulkan: optimize conv2d and implement coopmat1 support (llama/22620) * vulkan: add CONV_SHAPE_64x128 for medium-K conv2d * vulkan: skip conv2d bounds checks when shapes align with tile sizes * vulkan: use WG_SIZE=128 for CONV_SHAPE_64x32 conv2d * vulkan: stage cm2 conv2d accumulator through shmem before global store * vulkan: add coopmat1 conv2d path * fallback when using too much shared memory. clean up comments * Require 16x16x16 and subgroup size 32 or 64 * check whether shared memory is sufficient before overwriting conv2d params with coopmat1 values --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 119 +++++++++++-- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 159 ++++++++++++++++-- .../vulkan-shaders/vulkan-shaders-gen.cpp | 12 +- 3 files changed, 264 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index aa289220a90..18d7cedad4b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -398,6 +398,7 @@ enum vk_conv_shapes { CONV_SHAPE_128x128, CONV_SHAPE_64x32, CONV_SHAPE_32x256, + CONV_SHAPE_64x128, CONV_SHAPE_COUNT, }; @@ -412,6 +413,7 @@ vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = { { 128, 128, 16 }, // CONV_SHAPE_128x128 { 64, 32, 32 }, // CONV_SHAPE_64x32 { 32, 256, 16 }, // CONV_SHAPE_32x256 + { 64, 128, 16 }, // CONV_SHAPE_64x128 }; enum dmmv_wg_sizes { @@ -447,14 +449,16 @@ struct vk_fa_pipeline_state { }; struct vk_conv2d_pipeline_state { - vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH) - : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {} + vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH, uint32_t aligned) + : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH), aligned(aligned) {} uint32_t s0, s1, p0, p1, d0, d1, KW, KH; + // when set, shader can skip K/CRS/NPQ bounds checks and address clamps + uint32_t aligned; bool operator<(const vk_conv2d_pipeline_state &b) const { - return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) < - std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH); + return std::tie(s0, s1, p0, p1, d0, d1, KW, KH, aligned) < + std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH, b.aligned); } }; @@ -4934,7 +4938,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // conv2d, conv_transpose_2d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { - uint32_t conv2d_WG_SIZE = 256; + // smaller WG for the small-tile fallback gives more concurrent WGs per SM + uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256; uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8; uint32_t conv2d_SHMEM_PAD = 4; @@ -4973,18 +4978,77 @@ static void ggml_vk_load_shaders(vk_device& device) { conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. } - uint32_t conv2d_shmem_req = - (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float); - if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + // cm1 is used only when cm2 is unavailable; capped at 64x128 (due to shared memory size). + // Requires 16x16x16 f16-acc since that's the fragment shape hard-coded in the shader. + // Subgroup size must be 32 or 64 (to keep WG_SIZE sane) and we need + // subgroup_size_control to force the driver to actually use it. + bool conv2d_use_cm1 = false; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + conv2d_use_cm1 = !device->coopmat2 && + device->coopmat_support && device->coopmat_support_16x16x16_f16acc && + device->subgroup_size_control && + (device->subgroup_size == 32 || device->subgroup_size == 64) && + s != CONV_SHAPE_128x128; +#endif + + const uint32_t conv2d_cm1_shmem_pad = 8; + + auto shmem_req = [&](uint32_t pad, bool csh_store, bool fp16_shmem) { + const uint32_t elem_size = fp16_shmem ? (uint32_t)sizeof(uint16_t) : (uint32_t)sizeof(float); + const uint32_t csh_elems = csh_store ? conv2d_BS.K * conv2d_BS.NPQ : 0u; + return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size; + }; + + // coopmat1 needs to store the output through shared memory, so check up front + // whether it'll fit and disable it before applying coopmat1 parameters. + if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) { + conv2d_use_cm1 = false; + } + + uint32_t conv2d_WM = 16, conv2d_WN = 16; // cm1 subgroup tile, ignored otherwise + if (conv2d_use_cm1) { + conv2d_SHMEM_PAD = conv2d_cm1_shmem_pad; + // 16x16x16 fragments; pick WM/WN to keep WG_SIZE at 256 + // (i.e. 8 subgroups for sg=32, 4 subgroups for sg=64). + const bool sg64 = (device->subgroup_size == 64); + switch (s) { + case CONV_SHAPE_64x32: conv2d_WM = sg64 ? 32 : 16; conv2d_WN = 16; break; + case CONV_SHAPE_64x128: conv2d_WM = 32; conv2d_WN = sg64 ? 64 : 32; break; + case CONV_SHAPE_32x256: conv2d_WM = sg64 ? 16 : 32; conv2d_WN = sg64 ? 128 : 32; break; + default: break; + } + const uint32_t warps_M = conv2d_BS.K / conv2d_WM; + const uint32_t warps_N = conv2d_BS.NPQ / conv2d_WN; + conv2d_WG_SIZE = warps_M * warps_N * device->subgroup_size; + } + + // stage cm2 accumulator through shmem for coalesced global stores; + // skipped on 128x128 where the extra Csh footprint hurts occupancy. + // cm1 always uses the staged path. + uint32_t conv2d_csh_store = (device->coopmat2 && s != CONV_SHAPE_128x128) ? 1u : 0u; + if (conv2d_use_cm1) { + conv2d_csh_store = 1; + } + + // shmem is fp16 on cm2/cm1 (matches Csh), fp32 on scalar + const bool conv2d_use_fp16_shmem = device->coopmat2 || conv2d_use_cm1; + + // shrink CRS if the non-cm1 config still doesn't fit + if (device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_SHMEM_PAD, conv2d_csh_store, conv2d_use_fp16_shmem)) { + GGML_ASSERT(!conv2d_use_cm1); conv2d_BS.CRS = 8; if (use_collectives) { conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS); } + conv2d_csh_store = 0; } std::array wg_denoms = { conv2d_BS.K, 1, 1 }; std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + // cm1 needs a fixed subgroup width to match the WG_SIZE we computed + const uint32_t conv2d_required_subgroup_size = conv2d_use_cm1 ? device->subgroup_size : 0; + #define CREATE_CONV(name, type_suffix, spv_suffix) \ for (auto &c : device->pipeline_##name##type_suffix[s]) { \ const vk_conv2d_pipeline_state &state = c.first; \ @@ -4997,10 +5061,14 @@ static void ggml_vk_load_shaders(vk_device& device) { spec_constants_cpy.push_back(state.d1); \ spec_constants_cpy.push_back(state.KW); \ spec_constants_cpy.push_back(state.KH); \ + spec_constants_cpy.push_back(state.aligned); \ + spec_constants_cpy.push_back(conv2d_csh_store); \ + spec_constants_cpy.push_back(conv2d_WM); \ + spec_constants_cpy.push_back(conv2d_WN); \ ggml_vk_create_pipeline( \ device, c.second, #name #type_suffix, \ name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \ + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives || conv2d_required_subgroup_size, conv2d_required_subgroup_size); \ } #define CREATE_CONVS(spv_suffix) \ CREATE_CONV(conv2d, _f32, spv_suffix) \ @@ -5011,6 +5079,11 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->coopmat2) { CREATE_CONVS(_cm2) } else +#endif +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (conv2d_use_cm1) { + CREATE_CONVS(_cm1) + } else #endif if (conv2d_UNROLL) { CREATE_CONVS(_unroll) @@ -9473,10 +9546,23 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u // so small convolutions will still choose a smaller tile. const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; - if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { + // 128x128 isn't used with cm1 due to shared memory size; fall through to a smaller tile. + bool allow_128x128 = true; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (!ctx->device->coopmat2 && ctx->device->coopmat_support && ctx->device->coopmat_support_16x16x16_f16acc) { + allow_128x128 = false; + } +#endif + + if (allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { return CONV_SHAPE_128x128; } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) { return CONV_SHAPE_32x256; + } else if (K <= 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + return CONV_SHAPE_64x128; + } else if (!allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + // cm1 fallback for large K when 128x128 isn't available + return CONV_SHAPE_64x128; } else { return CONV_SHAPE_64x32; } @@ -10008,7 +10094,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0; uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1; uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1; - vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH); + + // tile-aligned shapes let the shader skip bounds checks + const uint32_t Cin = (uint32_t)src1->ne[2]; + const uint32_t CRS = Cin * KW * KH; + const uint32_t BS_K = vk_conv_block_sizes[shape].K; + const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS; + const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ; + const uint32_t aligned = ((K % BS_K == 0) && + (CRS % BS_CRS == 0) && + (NPQ % BS_NPQ == 0)) ? 1u : 0u; + + vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH, aligned); std::map *pipelines = nullptr; if (op == GGML_OP_CONV_2D) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 875c012cd3b..1428ef68d81 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -7,6 +7,13 @@ #extension GL_KHR_memory_scope_semantics : enable #endif +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + #ifdef USE_COLLECTIVES # extension GL_KHR_shader_subgroup_shuffle : enable #endif @@ -77,6 +84,39 @@ layout(constant_id = 12) const uint d1 = 1; // Kernel spatial sizes layout(constant_id = 13) const uint KW = 1; layout(constant_id = 14) const uint KH = 1; +// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned) +layout(constant_id = 15) const uint aligned = 0; +// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this. +layout(constant_id = 16) const uint csh_store = 0; + +#ifdef COOPMAT +// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of +// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN == +// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size. +layout(constant_id = 17) const uint WM = 32; +layout(constant_id = 18) const uint WN = 32; +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; +const uint cms_per_row = WM / TM; +const uint cms_per_col = WN / TN; +const uint warps_M = BS_K / WM; +const uint warps_N = BS_NPQ / WN; +#endif + +// without padding, H_idx/W_idx are in bounds by construction (non-TRANSPOSE only) +#ifdef TRANSPOSE +const bool hw_in_bounds = false; +#else +const bool hw_in_bounds = (p0 == 0) && (p1 == 0); +#endif + +// TRANSPOSE stride alignment is trivially satisfied for stride 1 +#ifdef TRANSPOSE +const bool stride_in_bounds = (s0 == 1) && (s1 == 1); +#else +const bool stride_in_bounds = true; +#endif uint32_t tid = gl_LocalInvocationID.x; const uint32_t WG_SIZE = gl_WorkGroupSize.x; @@ -94,7 +134,7 @@ uint32_t n_elems_out = K * NPQ; // Number of blocktiles per input uint32_t NB_CRS = splitWork(CRS, BS_CRS); -#ifdef COOPMAT2 +#if defined(COOPMAT2) || defined(COOPMAT) #define SHMEM_TYPE float16_t #else #define SHMEM_TYPE float @@ -112,6 +152,17 @@ const uint32_t Bsh_len = BS_CRS * Bsh_stride; shared SHMEM_TYPE Ash[Ash_len]; // K x CRS shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ +#if defined(COOPMAT2) || defined(COOPMAT) +// stage matC through shmem so global stores are row-major (NPQ-contiguous) +const uint32_t Csh_stride = BS_NPQ; +#ifdef COOPMAT +const uint32_t Csh_len = BS_K * Csh_stride; +#else +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +#endif +shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ +#endif + // Threadtile sizes const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; @@ -161,7 +212,7 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = D_TYPE(elem); } return elem; @@ -176,6 +227,13 @@ void main() { #ifdef COOPMAT2 coopmat matC; matC = coopmat(0.0); +#elif defined(COOPMAT) + coopmat sums[cms_per_row * cms_per_col]; + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0); + } + const uint warp_r = gl_SubgroupID / warps_N; + const uint warp_c = gl_SubgroupID % warps_N; #else float regC[TS_K][TS_NPQ]; for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -228,12 +286,15 @@ void main() { uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ #ifdef TRANSPOSE - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03; #else - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03; #endif + if (aligned == 0) { + knl_idx = min(knl_idx, K * CRS - 1); + } float val = knl_data[knl_idx]; - if (K_idx >= K || CRS_idx_a >= CRS) { + if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) { val = 0.0; } Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); @@ -282,15 +343,27 @@ void main() { uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1; uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0; #endif - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + uint32_t src_idx = W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13; + // skip clamp when address can't go OOB + if (aligned == 0 || !hw_in_bounds || !stride_in_bounds) { + src_idx = min(max(src_idx, 0), p.Cin * p.N * p.W * p.H - 1); + } float val = src_data[src_idx]; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ - || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case) + bool oob = false; + if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) { + oob = true; + } + // also catches lower-bound underflow (idx wraps to 0x80000000+) + if (!hw_in_bounds && (H_idx >= p.H || W_idx >= p.W)) { + oob = true; + } #ifdef TRANSPOSE - || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0) + if (!stride_in_bounds && + ((H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0))) { + oob = true; + } #endif - ) { + if (oob) { val = 0.0; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); @@ -303,6 +376,23 @@ void main() { coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); matC = coopMatMulAdd(matA, matB, matC); +#elif defined(COOPMAT) + // each subgroup multiplies its grid of fragments per TK-sized CRS chunk + [[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) { + coopmat cache_a[cms_per_row]; + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK; + coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopmat cache_b; + const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } #else if (T_y * TS_K < K) { UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { @@ -325,8 +415,51 @@ void main() { barrier(); } /* Save C* */ +#if defined(COOPMAT2) || defined(COOPMAT) + // stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores +#ifdef COOPMAT + const bool use_staged_store = true; +#else + const bool use_staged_store = (csh_store != 0); +#endif + if (use_staged_store) { +#ifdef COOPMAT + // cm1: each subgroup stores its fragment grid into its Csh slot + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN; + coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } +#else + coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); +#endif + barrier(); + + // cooperative shmem->global: WG threads spread across BS_NPQ (the + // contiguous direction of dst), each iter covers store_rows_per_iter K-rows + const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ; + const uint32_t store_iters = BS_K / store_rows_per_iter; + const uint32_t k_thread_offset = tid / BS_NPQ; + const uint32_t npq_thread = tid % BS_NPQ; + [[unroll]] for (uint32_t i = 0; i < store_iters; i++) { + uint32_t k_local = i * store_rows_per_iter + k_thread_offset; + uint32_t K_idx = B_idx_K * BS_K + k_local; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]); + } + } + } #ifdef COOPMAT2 - coopMatPerElementNV(matC, matC, perElemOpStore); + else { + coopMatPerElementNV(matC, matC, perElemOpStore); + } +#endif #else if (T_y * TS_K < K) { for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -337,7 +470,7 @@ void main() { uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = regC[T_ly][T_lx]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a1d735150fd..a0aac391298 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -984,8 +984,16 @@ void process_shaders() { string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines); #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (unroll) { - defines["COOPMAT2"] = "1"; - string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + auto cm2_defines = defines; + cm2_defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm2_defines, true, false, true); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (unroll) { + auto cm1_defines = defines; + cm1_defines["COOPMAT"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm1_defines, true, true, false); } #endif } From 6a249cd6400a2be44f2fdd5d38248aa2b36d5f92 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Wed, 27 May 2026 01:59:35 +0300 Subject: [PATCH 166/289] ggml-zendnn : fixed naming of matmul function (llama/20964) * ggml-zendnn: fixed naming of matmul function * ggml-zendnn: fixed naming of mul_mat_id function * ggml-zendnn: fixed print in mul_mat_id --------- Co-authored-by: plotnikov.v10 --- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6051d082003..3c33dcb11a0 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -88,7 +88,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int return true; } -static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, +static bool ggml_zendnn_gemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, const void * A, int64_t lda, const void * B, int64_t ldb, void * C, int64_t ldc, int Atype, int Btype, int Ctype) { @@ -200,7 +200,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i12 = 0; i12 < ne12; i12++) { const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m ne11, // n ne10, // k @@ -213,7 +213,7 @@ static void ggml_zendnn_compute_forward_mul_mat( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } } } @@ -355,7 +355,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } // batched gemm for all tokens in this expert - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m cne1, // n ne10, // k @@ -368,7 +368,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } // scatter output rows to destination From 80e87ec453081f649903c70168cf3279fe455eff Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Wed, 27 May 2026 17:48:40 +0800 Subject: [PATCH 167/289] vulkan: avoid preferring transfer queue on AMD UMA devices (llama/22455) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 18d7cedad4b..f45b9cfd1e9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5841,8 +5841,12 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); - // Only use transfer queue on AMD non-GCN, when the graphics queue is not enabled - const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !allow_graphics_queue; + // Prefer a dedicated transfer queue on AMD dGPUs (non-GCN) when graphics queue use is disabled. + const bool prefers_transfer_queue = + device->vendor_id == VK_VENDOR_ID_AMD && + device->architecture != AMD_GCN && + !device->uma && + !allow_graphics_queue; if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; From 98c6722fecccfca0c6ac947b487888bf375b90a2 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 27 May 2026 14:21:04 +0200 Subject: [PATCH 168/289] CUDA: restrict PDL to CTK >= 12.3 due to MSVC issues (llama/23742) --- ggml/src/ggml-cuda/common.cuh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e54ecb29308..50d7763dcdd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -110,11 +110,14 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 -// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA. +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8. +// However, this has been bugged in CTK < 12.3 for MSVC builds, see +// https://github.com/ggml-org/llama.cpp/pull/22522#discussion_r3302393293 // __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && \ + (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) # define GGML_CUDA_USE_PDL -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) static __device__ __forceinline__ void ggml_cuda_pdl_sync() { #if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER From c5cde8c7171dea86345338a7dee64d939b4c09cf Mon Sep 17 00:00:00 2001 From: l8bloom Date: Wed, 27 May 2026 16:59:08 +0200 Subject: [PATCH 169/289] vulkan: add REPEAT op support for f16 to f16. (llama/23298) * feat: extend repeat op for vulkan * feat: add repeat_f16 vulkan pipeline * fix: ensure same dst and src types * fix: use type_size instead of data types * fix: use int16 and int32 for repeat shader op * chore: rename repeat_f* to repeat_i* * chore: rename repeat vulkan pipelines --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 15 +++++++++++---- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f45b9cfd1e9..99b42f3bdf0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -768,7 +768,8 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; - vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i16; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; @@ -4708,9 +4709,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i32, "repeat_i32", repeat_i32_len, repeat_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i16, "repeat_i16", repeat_i16_len, repeat_i16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_UNARY(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -9738,7 +9741,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_REPEAT: if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { - return ctx->device->pipeline_repeat_f32; + return ctx->device->pipeline_repeat_i32; + } + if (ggml_type_size(src0->type) == 2 && ggml_type_size(dst->type) == 2) { + return ctx->device->pipeline_repeat_i16; } return nullptr; case GGML_OP_REPEAT_BACK: @@ -16253,7 +16259,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } case GGML_OP_REPEAT: - return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + return ggml_type_size(op->type) == ggml_type_size(op->src[0]->type) && + (ggml_type_size(op->type) == sizeof(float) || ggml_type_size(op->type) == 2); case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a0aac391298..24b9d25f733 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -798,9 +798,11 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); From 1b590bbb9ae834d31d1116e804249296bc83762c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 27 May 2026 10:18:28 -0500 Subject: [PATCH 170/289] vulkan: use GL_NV_cooperative_matrix_decode_vector for faster matmul (llama/23541) --- ggml/src/ggml-vulkan/CMakeLists.txt | 6 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 189 +++++- .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 4 + .../vulkan-shaders/dequant_funcs_cm2.glsl | 608 ++++++++++++++++++ .../feature-tests/coopmat2_decode_vector.comp | 7 + .../vulkan-shaders/flash_attn_cm2.comp | 42 +- .../vulkan-shaders/mul_mm_cm2.comp | 8 +- .../src/ggml-vulkan/vulkan-shaders/types.glsl | 7 + 8 files changed, 865 insertions(+), 6 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 65785ae4566..2d9e85794ad 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -79,6 +79,12 @@ if (Vulkan_FOUND) "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" ) + test_shader_extension_support( + "GL_NV_cooperative_matrix_decode_vector" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp" + "GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT" + ) + test_shader_extension_support( "GL_EXT_integer_dot_product" "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 99b42f3bdf0..fb07282ef76 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -21,6 +21,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include +// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the +// installed Vulkan headers predate the extension. +#ifndef VK_NV_cooperative_matrix_decode_vector +#define VK_NV_cooperative_matrix_decode_vector 1 +#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000) +typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { + VkStructureType sType; + void* pNext; + VkBool32 cooperativeMatrixDecodeVector; +} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV; +#endif + // SPIR-V Headers: different SDK installations expose different include paths. // LunarG Vulkan SDK on Windows typically provides . // Linux packages, MSYS2 and MinGW often use the Khronos layout . @@ -678,6 +691,7 @@ struct vk_device_struct { uint32_t coopmat_int_k; bool coopmat2; + bool coopmat2_decode_vector; bool pipeline_executable_properties_support {}; @@ -2167,6 +2181,136 @@ static uint32_t compile_count = 0; static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; +static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367; +static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447; +static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4; + +// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it +// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the +// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the +// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction. +// Returns true when the input used the extension (and `out` was populated with a +// stripped copy); returns false otherwise without touching `out`. +static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector & out) { + static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector"; + + if (word_count < 5) { + return false; + } + + bool uses_decode_vector = false; + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + uses_decode_vector = true; + break; + } + } + pos += wc; + } + + if (!uses_decode_vector) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector"); + + // Bulk-copy unchanged runs and only break the run when an instruction needs to + // be dropped or patched. Use reserve + insert/push_back so the destination buffer + // is touched exactly once (no zero-initialization pass from resize()). + out.clear(); + out.reserve(word_count); + + size_t run_start = 0; + auto flush_run = [&](size_t up_to) { + if (up_to > run_start) { + out.insert(out.end(), code + run_start, code + up_to); + } + }; + + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + } + + if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + + if (op == kSpvOpCooperativeMatrixLoadTensorNV) { + // [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...] + GGML_ASSERT(wc >= 8); + + uint32_t mem_mask = code[pos + 6]; + size_t cur = pos + 7; + // Each of these MemoryAccess bits (when set) carries one trailing operand. + cur += (mem_mask & 0x2) ? 1 : 0; // Aligned + cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable + cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible + cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask + cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask + GGML_ASSERT(cur < pos + wc); + + uint32_t ta_mask = code[cur]; + if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) { + pos += wc; + continue; // leave instruction inside the current unchanged run + } + + flush_run(pos); + + // Append unchanged prefix of the instruction (header through the mem-extras). + size_t inst_start = out.size(); + size_t pre_n = cur - pos; + out.insert(out.end(), code + pos, code + pos + pre_n); + + // Emit TA mask with the DecodeVectorFunc bit cleared. + out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit); + + // TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim; + // DecodeVectorFunc (0x4) is dropped along with its trailing id operand. + size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0); + if (keep_ta_extras) { + out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras); + } + + GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1); + + // Patch the instruction header with the new (one-shorter) word count. + uint32_t new_wc = wc - 1; + out[inst_start] = (new_wc << spv::WordCountShift) | op; + + pos += wc; + run_start = pos; + continue; + } + + pos += wc; + } + + flush_run(word_count); + return true; +} + static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { @@ -2238,6 +2382,18 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (device->coopmat2 && !device->coopmat2_decode_vector) { + const uint32_t * src = spirv.empty() ? reinterpret_cast(spv_data) : spirv.data(); + size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size(); + std::vector stripped; + if (ggml_vk_strip_decode_vector(src, src_n, stripped)) { + spirv = std::move(stripped); + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); + } + } +#endif + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); vk::PushConstantRange pcr( @@ -5159,6 +5315,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool amd_shader_core_properties2 = false; bool pipeline_robustness = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; @@ -5193,6 +5350,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -5470,6 +5630,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME); + } + #if defined(VK_KHR_shader_bfloat16) VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; bfloat16_features.pNext = nullptr; @@ -5629,6 +5797,7 @@ static vk_device ggml_vk_get_device(size_t idx) { found_fp32_128 && found_fp32_256 && coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { device->coopmat2 = true; + device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; } } #endif @@ -5915,6 +6084,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool fp16_compute = false; bool coopmat_support = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool integer_dot_product = false; bool bfloat16_support = false; @@ -5933,6 +6103,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -6017,6 +6190,13 @@ static void ggml_vk_print_gpu_info(size_t idx) { } #endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + } + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; @@ -6041,7 +6221,14 @@ static void ggml_vk_print_gpu_info(size_t idx) { #endif && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); - std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; + coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; +#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + coopmat2_decode_vector_support = false; +#endif + + std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2") + : coopmat_support ? "KHR_coopmat" + : "none"; std::string device_name = props2.properties.deviceName.data(); GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index e1f613fb4f6..10a9ea21025 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -11,6 +11,10 @@ if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) message(STATUS "Enabling coopmat2 glslc support") endif() +if (GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 decode_vector glslc support") +endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) message(STATUS "Enabling dot glslc support") diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index c582aba87dc..7171cbfa559 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,4 +1,12 @@ +// Each format defines a scalar dequantFunc plus a V=4 dequantFunc_v +// passed as the optional vector decoder to coopMatLoadTensorNV via +// GL_NV_cooperative_matrix_decode_vector. When the driver doesn't support +// the extension, ggml-vulkan.cpp strips it from the compiled SPIR-V. +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif + #include "types.glsl" layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 { @@ -25,6 +33,19 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2 return bit != 0u ? d : -d; } +f16vec4 dequantFuncQ1_0_v(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t md = -d; + const uint idx = coordInBlock[1]; + const uint qs_nib = uint(bl.block.qs[idx >> 3]) >> (idx & 0x4u); + return f16vec4( + (qs_nib & 1u) != 0u ? d : md, + (qs_nib & 2u) != 0u ? d : md, + (qs_nib & 4u) != 0u ? d : md, + (qs_nib & 8u) != 0u ? d : md); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -42,10 +63,28 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_0_v(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xE) >> 1; // even, in {0,2,4,6} + const uint qsw = uint32_t(bl.block.qs[qs_i ]) + | (uint32_t(bl.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const uint q4 = (qsw >> shift) & 0x0F0F0F0Fu; + const u8vec4 q = unpack8(q4); + return f16vec4((vec4(q) - vec4(8.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { block_q4_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1_packed32 { + block_q4_1_packed32 block; +}; + float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -60,10 +99,27 @@ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_1_v(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_1_packed32 bl32 = decodeBufQ4_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4(vec4(q) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { block_q5_0 block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0_packed16 { + block_q5_0_packed16 block; +}; + float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -82,10 +138,32 @@ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_0_v(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_0_packed16 bl16 = decodeBufQ5_0_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint uint_qh = uint(bl16.block.qh[1]) << 16 | uint(bl16.block.qh[0]); + const uint qh_pack = uint_qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high) - vec4(16.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { block_q5_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1_packed32 { + block_q5_1_packed32 block; +}; + float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -105,6 +183,23 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_1_v(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_1_packed32 bl32 = decodeBufQ5_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint qh_pack = bl.block.qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high)) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { block_q8_0_packed16 block; }; @@ -121,6 +216,17 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ8_0_v(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint base = idx >> 1u; + const uint w = uint(uint16_t(bl.block.qs[base])) + | (uint(uint16_t(bl.block.qs[base + 1u])) << 16u); + const i8vec4 qi = unpack8(int32_t(w)); + return f16vec4(vec4(qi) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { block_q2_K block; }; @@ -129,6 +235,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 block_q2_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K_packed32 { + block_q2_K_packed32 block; +}; + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); @@ -147,10 +257,36 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ2_K_v(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed32 bl32 = decodeBufQ2_K_packed32(bl); + const f16vec2 dm = bl.block.dm; + const uint idx = coordInBlock[1]; + + const uint scalesi = idx >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + // qs_i (packed16) = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1) is even for idx % 4 == 0, + // so qs_w (packed32) = qs_i / 2 = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2). + const uint qs_w = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2); + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const uint qs4 = (qsw >> qsshift) & 0x03030303u; + const u8vec4 qi = unpack8(qs4); + + const uint scales = bl.block.scales[scalesi]; + const float16_t d_sub = dm.x * float16_t(scales & 0xF); + const float16_t m_sub = dm.y * float16_t(scales >> 4); + return f16vec4(vec4(qi) * vec4(float(d_sub)) - vec4(float(m_sub))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { block_q3_K block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K_packed16 { + block_q3_K_packed16 block; +}; + float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const uint idx = coordInBlock[1]; @@ -179,6 +315,47 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ3_K_v(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ3_K_packed16 bl16 = decodeBufQ3_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint n = idx >> 7; // 0,1 + const uint is = idx >> 4; // 0..15 + const uint halfsplit = (idx & 0x60) >> 5; // 0,1,2,3 + const uint qsshift = halfsplit << 1; // 0,2,4,6 + const uint hbit = (n << 2) + halfsplit; // 0..7 (bit position in hmask byte) + + uint32_t scaleidx0 = (is < 8) ? is : (is - 8); + uint32_t scaleidx0shift = (is < 8) ? 0u : 4u; + uint32_t scaleidx1 = is + 8 - (is / 4) * 4; + uint32_t scaleidx1shift = (is / 4) * 2; + + const int8_t us = int8_t( + ((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | + (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + const float16_t dl = bl.block.d * float16_t(int(us) - 32); + + // For idx % 4 == 0: (idx & 0x1F) == (idx & 0x1C) is a multiple of 4. + const uint qsi = (n << 5) + (idx & 0x1Cu); + const uint hmi = (idx & 0x1Cu); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qsw / hmw holds the data for element idx+j. + const uint qsw = uint32_t(bl16.block.qs[qsi >> 1]) + | (uint32_t(bl16.block.qs[(qsi >> 1) + 1u]) << 16); + const uint hmw = uint32_t(bl16.block.hmask[hmi >> 1]) + | (uint32_t(bl16.block.hmask[(hmi >> 1) + 1u]) << 16); + + // qsshift in {0,2,4,6} and hbit in {0..7}: per-byte masks isolate the wanted bits + // with no inter-byte leakage. + const uint ql4 = (qsw >> qsshift) & 0x03030303u; + const uint qh4 = (hmw >> hbit) & 0x01010101u; + + const ivec4 q = ivec4(unpack8(ql4 | (qh4 << 2))) - ivec4(4); + return f16vec4(vec4(q) * vec4(float(dl))); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { block_q4_K block; }; @@ -187,6 +364,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed32 { + block_q4_K_packed32 block; +}; + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { block_q4_K_packed128 block; }; @@ -334,6 +515,55 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ4_K_v(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed32 bl32 = decodeBufQ4_K_packed32(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q4k[0]; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif + + // idx in [0,256); vector decode uses idx a multiple of 4. packed32 word index: + // (qs_i >> 1) == (idx >> 6) * 8 + ((idx & 0x1E) >> 2). sh is 0 or 4 only, so a + // single (w >> sh) & 0x0F0F0F0F isolates all four nibbles without inter-byte leakage. + const uint sh = (idx & 0x20u) >> 3u; + const uint w = uint32_t(bl32.block.qs[(idx >> 6) * 8u + ((idx & 0x1Eu) >> 2)]); + const u8vec4 q = unpack8((w >> sh) & 0x0F0F0F0Fu); + + return f16vec4(vec4(d) * vec4(q) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { block_q5_K block; }; @@ -346,6 +576,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 block_q5_K_packed128 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed32 { + block_q5_K_packed32 block; +}; + float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); @@ -399,6 +633,58 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ5_K_v(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed32 bl32 = decodeBufQ5_K_packed32(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + const uint is = idx >> 5; + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); +#endif + + // sh is 0 or 4; mask 0x0F0F0F0F covers the four nibbles regardless (no inter-byte leakage). + const uint sh = (idx & 0x20u) >> 3u; + const uint qs_w = (idx >> 6) * 8u + ((idx & 0x1Eu) >> 2); + const uint qh_w = (idx & 0x1Eu) >> 2; + + const uint ql4 = (uint32_t(bl32.block.qs[qs_w]) >> sh) & 0x0F0F0F0Fu; + // qh stores bit `is` per element across 4 consecutive bytes; one shift+mask handles all 4. + const uint qh4 = ((uint32_t(bl32.block.qh[qh_w]) >> is) & 0x01010101u) << 4u; + + const u8vec4 qi = unpack8(ql4 | qh4); + return f16vec4(vec4(qi) * vec4(d) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { block_q6_K block; }; @@ -431,6 +717,35 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ6_K_v(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = idx >> 4; + const uint sh = b * 4; // 0 or 4 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + const uint ql_i = ((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1); + const uint qh_i = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qlw / qhw holds the data for element idx+j. + const uint qlw = uint32_t(bl16.block.ql[ql_i ]) | (uint32_t(bl16.block.ql[ql_i + 1]) << 16); + const uint qhw = uint32_t(bl16.block.qh[qh_i ]) | (uint32_t(bl16.block.qh[qh_i + 1]) << 16); + + // sh in {0,4} and qhshift in {0,2,4,6}: per-byte masks 0x0F / 0x03 keep only the + // wanted bits with no inter-byte leakage; place qh's 2 bits at nibble high position. + const uint ql4 = (qlw >> sh) & 0x0F0F0F0Fu; + const uint qh4 = ((qhw >> qhshift) & 0x03030303u) << 4u; + + const ivec4 qi = ivec4(unpack8(ql4 | qh4)); + return f16vec4((vec4(qi) - vec4(32.0f)) * vec4(float(dscale))); +} + #if defined(DATA_A_IQ1_S) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { block_iq1_s block; @@ -453,6 +768,29 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_S_v(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const int i8b = int(idx & 4); // 0 or 4 + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = float(d) * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * dl); +} #endif #if defined(DATA_A_IQ1_M) @@ -485,6 +823,33 @@ float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_M_v(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); + const uint idx = coordInBlock[1]; + + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = idx >> 3; + const uint ib16 = idx >> 4; + const int i8b = int(idx & 4); // 0 or 4 -- i8 base for the V=4 group + + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2.0 * float(bitfieldExtract(sc, 3 * int(ib16 & 3), 3)) + 1.0; + const float delta = ((qh & 8u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7u) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * (float(d) * dl)); +} #endif #if defined(DATA_A_IQ2_XXS) @@ -520,6 +885,33 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XXS_v(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = (idx & 0x18) >> 3; + const uint iqs = 8 * ib32 + ib8; + + const uint qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_XS) @@ -548,6 +940,31 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XS_v(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; + const uint sshift = (idx & 0x10) >> 2; + const uint iqs = idx >> 3; + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_S) @@ -576,6 +993,32 @@ float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ2_S_v(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sb = uint(bl.block.qs[QUANT_K / 8 + ib8]) >> (idx & 0x6u); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + + const uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_XXS) @@ -609,6 +1052,32 @@ float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCo const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_XXS_v(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3); + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2(bl16.block.qs[is/2+0], bl16.block.qs[is/2+1])); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + + const uint sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sb = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6u); + + const uint grid = iq3xxs_grid[qs]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_S) @@ -635,6 +1104,30 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_S_v(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint iqh = idx >> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const uint sb = uint(bl.block.signs[iqs / 2]) >> (idx & 0x6u); + const uint scale = bl.block.scales[iqs / 16]; + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + + const uint grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ4_XS) @@ -642,6 +1135,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_xs block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufIQ4_XS_packed32 { + block_iq4_xs_packed32 block; +}; + float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -657,6 +1154,30 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); return ret; } + +f16vec4 dequantFuncIQ4_XS_v(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_XS_packed32 bl32 = decodeBufIQ4_XS_packed32(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; // 0..7 + const uint sl = (bl32.block.scales_l >> (4 * ib32)) & 0xF; + const uint sh = (uint(bl32.block.scales_h) >> (2 * ib32)) & 0x3; + const uint qshift = (idx & 0x10) >> 2; // {0, 4} + const uint qs_w = 4 * ib32 + ((idx & 0xC) >> 2); // iqs / 4, in [0,32) + + const float16_t dl = d * float16_t(int(sl | (sh << 4)) - 32); + + const uint qsw = bl32.block.qs[qs_w]; + const u8vec4 qv = unpack8((qsw >> qshift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_iq4nl[qv.x]), + float(kvalues_iq4nl[qv.y]), + float(kvalues_iq4nl[qv.z]), + float(kvalues_iq4nl[qv.w])) * float(dl); + return f16vec4(ret); +} #endif #if defined(DATA_A_IQ4_NL) @@ -664,6 +1185,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_nl block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL_packed16 { + block_iq4_nl_packed16 block; +}; + float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -676,6 +1201,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; return ret; } + +f16vec4 dequantFuncIQ4_NL_v(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_NL_packed16 bl16 = decodeBufIQ4_NL_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4( + float(d) * float(kvalues_iq4nl[q.x]), + float(d) * float(kvalues_iq4nl[q.y]), + float(d) * float(kvalues_iq4nl[q.z]), + float(d) * float(kvalues_iq4nl[q.w])); +} #endif #if defined(DATA_A_MXFP4) @@ -695,6 +1238,26 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; } + +f16vec4 dequantFuncMXFP4_v(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uvec4 qv = uvec4( + uint(bl.block.qs[iqs]), + uint(bl.block.qs[iqs + 1u]), + uint(bl.block.qs[iqs + 2u]), + uint(bl.block.qs[iqs + 3u])); + qv = (qv >> shift) & 0xFu; + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} #endif #if defined(DATA_A_NVFP4) @@ -702,6 +1265,10 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVF block_nvfp4 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4_packed32 { + block_nvfp4_packed32 block; +}; + float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const uint idx = coordInBlock[1]; @@ -713,56 +1280,97 @@ float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords qs = (qs >> shift) & 0xF; return float16_t(kvalues_mxfp4[qs] * d * 0.5); } + +f16vec4 dequantFuncNVFP4_v(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufNVFP4_packed32 bl32 = decodeBufNVFP4_packed32(bl); + const uint idx = coordInBlock[1]; + const uint sub = idx >> 4; + const uint qs_w = ((idx & 0x30) >> 3) + ((idx & 0x4u) >> 2); // iqs / 4, in [0,8) + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 qv = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} #endif #if defined(DATA_A_Q1_0) #define dequantFuncA dequantFuncQ1_0 +#define dequantFuncA_v dequantFuncQ1_0_v #elif defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 +#define dequantFuncA_v dequantFuncQ4_0_v #elif defined(DATA_A_Q4_1) #define dequantFuncA dequantFuncQ4_1 +#define dequantFuncA_v dequantFuncQ4_1_v #elif defined(DATA_A_Q5_0) #define dequantFuncA dequantFuncQ5_0 +#define dequantFuncA_v dequantFuncQ5_0_v #elif defined(DATA_A_Q5_1) #define dequantFuncA dequantFuncQ5_1 +#define dequantFuncA_v dequantFuncQ5_1_v #elif defined(DATA_A_Q8_0) #define dequantFuncA dequantFuncQ8_0 +#define dequantFuncA_v dequantFuncQ8_0_v #elif defined(DATA_A_Q2_K) #define dequantFuncA dequantFuncQ2_K +#define dequantFuncA_v dequantFuncQ2_K_v #elif defined(DATA_A_Q3_K) #define dequantFuncA dequantFuncQ3_K +#define dequantFuncA_v dequantFuncQ3_K_v #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K +#define dequantFuncA_v dequantFuncQ4_K_v #define fetch_scales fetch_scalesQ4_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q5_K) #define dequantFuncA dequantFuncQ5_K +#define dequantFuncA_v dequantFuncQ5_K_v #define fetch_scales fetch_scalesQ5_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#define dequantFuncA_v dequantFuncQ6_K_v #elif defined(DATA_A_IQ1_S) #define dequantFuncA dequantFuncIQ1_S +#define dequantFuncA_v dequantFuncIQ1_S_v #elif defined(DATA_A_IQ1_M) #define dequantFuncA dequantFuncIQ1_M +#define dequantFuncA_v dequantFuncIQ1_M_v #elif defined(DATA_A_IQ2_XXS) #define dequantFuncA dequantFuncIQ2_XXS +#define dequantFuncA_v dequantFuncIQ2_XXS_v #elif defined(DATA_A_IQ2_XS) #define dequantFuncA dequantFuncIQ2_XS +#define dequantFuncA_v dequantFuncIQ2_XS_v #elif defined(DATA_A_IQ2_S) #define dequantFuncA dequantFuncIQ2_S +#define dequantFuncA_v dequantFuncIQ2_S_v #elif defined(DATA_A_IQ3_XXS) #define dequantFuncA dequantFuncIQ3_XXS +#define dequantFuncA_v dequantFuncIQ3_XXS_v #elif defined(DATA_A_IQ3_S) #define dequantFuncA dequantFuncIQ3_S +#define dequantFuncA_v dequantFuncIQ3_S_v #elif defined(DATA_A_IQ4_XS) #define dequantFuncA dequantFuncIQ4_XS +#define dequantFuncA_v dequantFuncIQ4_XS_v #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#define dequantFuncA_v dequantFuncIQ4_NL_v #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#define dequantFuncA_v dequantFuncMXFP4_v #elif defined(DATA_A_NVFP4) #define dequantFuncA dequantFuncNVFP4 +#define dequantFuncA_v dequantFuncNVFP4_v #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp new file mode 100644 index 00000000000..65e9c678401 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix_decode_vector : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 141bb870883..6d45b4931df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -11,6 +11,9 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -54,6 +57,41 @@ float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const } } +// V=4 vector decode for K/V; dispatches to per-format _v decoders. +f16vec4 faDecodeKVector(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +#ifdef GL_NV_cooperative_matrix_decode_vector +#define FADECODEK , faDecodeK, faDecodeKVector +#define FADECODEV , faDecodeV, faDecodeVVector +#else +#define FADECODEK , faDecodeK +#define FADECODEV , faDecodeV +#endif + layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; @@ -259,7 +297,7 @@ void main() { // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. const bool k_use_decode = (bs_k > 1u); if (k_use_decode) { - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK); } else { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); } @@ -325,7 +363,7 @@ void main() { uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; const bool v_use_decode = (bs_v > 1u); if (v_use_decode) { - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV); } else { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 497a18ff8a7..250d708479b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -71,10 +71,12 @@ layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 -#define DECODEFUNCA , dequantFuncA - #include "dequant_funcs_cm2.glsl" - +#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector) +#define DECODEFUNCA , dequantFuncA, dequantFuncA_v +#else +#define DECODEFUNCA , dequantFuncA +#endif #else #define DECODEFUNCA #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 4bcd97756fd..06eff6f219f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1722,11 +1722,18 @@ struct block_nvfp4 uint8_t qs[QUANT_K_NVFP4 / 2]; }; +struct block_nvfp4_packed32 +{ + uint32_t d[QUANT_K_NVFP4 / 16 / 4]; + uint32_t qs[QUANT_K_NVFP4 / 2 / 4]; +}; + #if defined(DATA_A_NVFP4) #define QUANT_K QUANT_K_NVFP4 #define QUANT_R QUANT_R_NVFP4 #define QUANT_AUXF 1 #define A_TYPE block_nvfp4 +#define A_TYPE_PACKED32 block_nvfp4_packed32 #endif #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) From 8bce478ee8be5cc5f78d6d38cbacdd4f6f1ae64e Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Wed, 27 May 2026 15:19:23 +0000 Subject: [PATCH 171/289] vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 (llama/22887) * vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 Against mesa git, this shows a 4.8% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. Note that this breaks some tests until the last commit which fixes OOB A reads. * vulkan: Use aligned loads in mul_mat_vec when available Against mesa git, this shows a 3.3% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. * Make explicit that `num_rows` is <= `NUM_ROWS` in mul_mat_vec Mesa's UUB logic can't see through conditionals, limiting its ability to understand the bounds on the `num_rows` field in the cleanup run. Making it explicit that `num_rows` is, indeed, always <= `NUM_ROWS` helps mesa make slightly better codegen. Against mesa git, this currently shows a 1% performance improvement in tg128 on Qwen3.5-9B:BF16 on Intel BMG. * vulkan: Fix OOB A reads in MUL_MAT_VEC for odd sizes There was a TODO to fix the OOB reads from the A matrix which we do here. It is within performance noise (+<0.1%) in tg128 for Qwen3.5-9B:BF16 on Intel BMG. --- .../vulkan-shaders/dequant_funcs.glsl | 39 +++++ .../vulkan-shaders/mul_mat_vec.comp | 149 ++++++++++++++---- .../src/ggml-vulkan/vulkan-shaders/types.glsl | 2 + 3 files changed, 163 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 88d07d2dfd5..e67299fdeca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -5,21 +5,60 @@ #include "types.glsl" #if defined(DATA_A_F32) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} + #endif #if defined(DATA_A_F16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const vec2 a = data_a_packed32[(a_offset + ib)/2]; + const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(a, b); +} #endif #if defined(DATA_A_BF16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return bf16_to_fp32(data_a[a_offset + ib]); +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]), + bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3])); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const uint a = data_a_packed32[(a_offset + ib)/2]; + const uint b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(uintBitsToFloat((a & 0x0000ffff) << 16), + uintBitsToFloat( a & 0xffff0000), + uintBitsToFloat((b & 0x0000ffff) << 16), + uintBitsToFloat( b & 0xffff0000)); +} #endif #if defined(DATA_A_Q4_0) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 2271be4021b..5a9d0e778fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else -#define K_PER_ITER 2 +#define K_PER_ITER 4 #endif uint a_offset, b_offset, d_offset, y_offset; +vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) { + // Check if the latter elements are OOB, and don't fetch B or accumulate it. + OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols); + OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols); + OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols); + + if (!OOB_w) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3])); + } else if (!OOB_z) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + 0); + } else if (!OOB_y) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + 0, 0); + } else { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + 0, 0, 0); + } +} + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { @@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #if K_PER_ITER == 8 #if QUANT_R == 2 + // Note that we end up fetching bogus elements here, but its fine as they'll be + // within an accessible block. const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); @@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); #endif #else - // Check if the second of the pair of elements is OOB, and don't fetch B or - // accumulate it. We still fetch a pair of elements for A, which is fine for - // quantized formats since they'll be within the same block. We should - // probably skip fetching the second element for F16/F32, but as of now we - // still do. - const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); - - FLOAT_TYPE b0 = 0, b1 = 0; - b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); - if (!OOB) { - b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); - } + bool OOB_y; + bool OOB_z; + bool OOB_w; + + const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w); #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const temp[j][n] += rowtmp; #else - const vec2 v = dequantize(ib, iqs, a_offset); - - // matrix multiplication - temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); - if (!OOB) { - temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + if (!OOB_w) { + const vec4 v = dequantize4(ib, iqs, a_offset); + temp[j][n] += dot(v, b); + } else if (!OOB_z) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset); + const vec3 v = vec3(v0.x, v0.y, v1); + const vec3 b0 = vec3(b.x, b.y, b.z); + temp[j][n] += dot(v, b0); + } else if (!OOB_y) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 b0 = vec2(b.x, b.y); + temp[j][n] += dot(v0, b0); + } else { + const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset); + temp[j][n] = fma(v, b.x, temp[j][n]); } #endif } } } +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = 0; // quant index + const uint iybs = col; // y block start index + + const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + + const vec4 v = dequantize4_2aligned(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] += dot(v, b); + } + } +} +#endif + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); + const bool is_aligned_nonquant = + p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 && + p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 && + K_PER_ITER == 4; y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; @@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { int unroll_count = 4; uint unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 + uint i = 0; + +#if K_PER_ITER == 4 // If the K dimension is odd, we need lastiter==true on the last iteration // so OOB is computed correctly. Skip some unrolling to make that happen. - if ((p.ncols & 1) != 0 && + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } + if (is_aligned_nonquant) { + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { #endif - - uint i = 0; while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && +#if K_PER_ITER == 4 + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } -#endif + if (is_aligned_nonquant) { + while (i < unrolled_iters && is_aligned_nonquant) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { +#endif while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif + +#if K_PER_ITER == 4 + if (is_aligned_nonquant) { + while (i < num_iters) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } else { +#endif while (i < num_iters) { iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); i++; } +#if K_PER_ITER == 4 + } +#endif reduce_result(temp, d_offset, first_row, num_rows, tid); } @@ -164,6 +259,6 @@ void main() { if (first_row >= p.stride_d) { return; } - compute_outputs(first_row, p.stride_d - first_row); + compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row)); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 06eff6f219f..f84d6f87334 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -31,6 +31,7 @@ #else #define A_TYPE float16_t #endif +#define A_TYPE_PACKED32 f16vec2 #endif #if defined(DATA_A_BF16) @@ -44,6 +45,7 @@ #else #define A_TYPE uint16_t #endif +#define A_TYPE_PACKED32 uint32_t #endif #define QUANT_K_Q4_0 32 From a52bd385d678e152774c211dc7a8ac372650558b Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 28 May 2026 01:48:12 +0900 Subject: [PATCH 172/289] ggml-webgpu: Fix how to dispatch WG to some ops (llama/23750) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 53 +++++++++++-------- ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl | 10 ++-- .../wgsl-shaders/mul_mat_id_gather.wgsl | 43 +++++++-------- 3 files changed, 57 insertions(+), 49 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f113da909ce..f6d17a073be 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -749,8 +749,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, @@ -974,9 +977,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x; + uint32_t wg_y; uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } @@ -1064,9 +1068,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x; + uint32_t wg_y; uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } @@ -1689,14 +1694,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, gathered_count_ids_binding_size), }; - const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - - const uint32_t gather_total_wg = param_n_expert; - const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); - const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); + // n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_exeprt=256 at Qwen3.5-35B-A3B) + const uint32_t gather_wg_x = param_n_expert; dispatches.push_back({ - gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y } + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 } }); // params for mul_mat_id.wgsl @@ -1748,7 +1750,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts; uint32_t total_wg = wg_m * max_wg_n; - compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); dispatches.push_back({ main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } @@ -2771,10 +2773,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * block_size, npr, nrows }; - const uint32_t total_wg_init = npr * nrows; - const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - const uint32_t wg_x_init = std::min(total_wg_init, max_wg); - const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); + uint32_t wg_x_init; + uint32_t wg_y_init; + const uint32_t total_wg_init = npr * nrows; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init); + std::vector init_entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) @@ -2831,9 +2835,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; + uint32_t wg_x_merge; + uint32_t wg_y_merge; const uint32_t total_wg_merge = nm * nrows; - const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); - const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); + compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge); + dispatches.push_back({ argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } }); @@ -2953,9 +2959,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); - uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index fa3bdf4e393..e268adfb16b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -49,12 +49,14 @@ struct Params{ var params: Params; @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { +fn main( + @builtin(global_invocation_index) gindex: u32, +) { + if (gindex >= params.ne) { return; } - var i = gid.x; + var i = gindex; let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); let i2 = i / (params.src_ne1 * params.src_ne0); @@ -62,7 +64,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i1 = i / params.src_ne0; let i0 = i % params.src_ne0; - var j = gid.x; + var j = gindex; let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); let j2 = j / (params.dst_ne1 * params.dst_ne0); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl index d79d5f3f282..581e922709d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl @@ -21,35 +21,32 @@ var count:atomic; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { + @builtin(local_invocation_id) local_id: vec3) { let thread_id = local_id.x; - let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup + let own_expert = wg_id.x; // the expert assigned to this workgroup - if (own_expert < params.n_expert) { - if (thread_id == 0u) { - atomicStore(&count, 0); - } + if (thread_id == 0u) { + atomicStore(&count, 0); + } - workgroupBarrier(); - - for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { - let row = i / params.n_expert_used; - let col = i % params.n_expert_used; - let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); - if (own_expert == expert) { - let pos = atomicAdd(&count, 1u); - let gathered_id = own_expert * params.n_tokens + pos; - global_gathered_expert_used[gathered_id] = col; - global_gathered_tokens[gathered_id] = row; - } + workgroupBarrier(); + + for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { + let row = i / params.n_expert_used; + let col = i % params.n_expert_used; + let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); + if (own_expert == expert) { + let pos = atomicAdd(&count, 1u); + let gathered_id = own_expert * params.n_tokens + pos; + global_gathered_expert_used[gathered_id] = col; + global_gathered_tokens[gathered_id] = row; } + } - workgroupBarrier(); + workgroupBarrier(); - if (thread_id == 0u) { - gathered_count_ids[own_expert] = atomicLoad(&count); - } + if (thread_id == 0u) { + gathered_count_ids[own_expert] = atomicLoad(&count); } } From 3bbe93378cc96339d362e4dbf490df10412ad389 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 27 May 2026 10:46:11 -0700 Subject: [PATCH 173/289] hexagon: add support for Q4_1 in MUL_MAT and MUL_MAT_ID (llama/23647) * hex-mm: add support for Q4_1 matmul/matvec, hvx-only for now * hmx-mm: add support for Q4_1 * hex-mm: use Q8_1 dynamic quantization to avoid having to compute sums in the vec_dot * hexagon: fix repack scratch buffer overflow * hex-mm: fix Q4_1 repack buffer sizing * hexagon: flip the build order for mm and fa (seems to help LTO) * hex-mm: add vec_dot 4x1s and minor HMX cleanup after adding Q4_1 * hex-mm: fix fp16 vec_dot fallback to 2x1 and another issue that could cause incorrect output * hexagon: resurrect early-wake and add support for polling for op-batch completions With Q4_1 ggml-hexagon now claims pretty much the entire graphs which gives the CPU more time to chilax. This is a good thing! But it does add extra latency for the pure benchmark runs. Early wakeup helps recover the latency a bit in the normals runs and op-batch polling is just for benchmarking. --------- Co-authored-by: Todor Boinovski --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 267 ++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 4 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 155 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 7 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 1769 ++++++++++++++++++-- 6 files changed, 2004 insertions(+), 200 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 1c8ecc197e9..5e8a4a740c1 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -68,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches +static int opt_oppoll = 0; // polling for batch completions static std::regex* opt_opfilter = NULL; // regex of ops to not claim @@ -550,7 +551,7 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -611,7 +612,7 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -660,6 +661,239 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) ggml_aligned_free(buf_rp, row_size_rp); } +static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); + + bool partial = (nloe && i == nb-1); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; + } + } + + // Repack the scales and offsets + for (int i = 0; i < nb; i++) { + ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + d_m[j * 2 + 0] = x[i * 8 + j].d; + d_m[j * 2 + 1] = x[i * 8 + j].m; + } + } +} + +static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; + bool partial = (nloe && i == nb-1); + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + if (partial) { + qs[j*2+0] = q[j] & 0x0F; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0x0F; + qs[j+128] = q[j] >> 4; + } + } + + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + // Unpack the scales and offsets + for (int i = 0; i < nb; i++) { + const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = d_m[j * 2 + 0]; + x[i * 8 + j].m = d_m[j * 2 + 1]; + } + } +} + +static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = 0; + x[i * 8 + j].m = 0; + } + } +} + +static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + // We still need to read and unpack the entire source row because quantization is block-based. + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + // ======== Q8x4x2 ==================== static void dump_block_q8_0(const block_q8_0 * b, int i) { HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], @@ -876,7 +1110,7 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -937,7 +1171,7 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1238,7 +1472,7 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -1299,7 +1533,7 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1365,6 +1599,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q4_0_q4x4x2(tensor, data, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1407,6 +1647,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_1(data, tensor, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1886,7 +2132,8 @@ void ggml_hexagon_session::flush_pending(bool all) { uint32_t n_dbufs; // Read response packet from queue - int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT); + const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); if (err == AEE_EEXPIRED) { continue; } @@ -2327,6 +2574,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -2377,6 +2625,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -3622,6 +3871,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, @@ -3634,6 +3885,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); @@ -3671,6 +3923,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 33d67dda9cc..d7927261a85 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -59,14 +59,14 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE hmx-queue.c - hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-matmul-ops.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( - hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-matmul-ops.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 3ef0bcdb26d..ab5fd73380b 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -34,6 +34,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; +static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, +}; + // MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value // kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { @@ -62,6 +66,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q4_1: + return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb case HTP_TYPE_Q8_0: return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: @@ -233,6 +239,54 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( return r; } +static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_dm = hvx_vmemu(scale_offset); + HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); + HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); + + HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); + HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); @@ -331,11 +385,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int start_tile, int end_tile) { const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1); const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : + (weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. @@ -356,8 +412,10 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 bool upper = (sub_blk_base >= 4); unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE - + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; + unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); + unsigned scale_off = qrow_size + blk_idx * dblk_size + + sub_blk_base * scale_step; __fp16 *tile_bases[4]; for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } @@ -367,20 +425,38 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + if (is_q4_1) { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + } else { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } } for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } @@ -446,26 +522,43 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; bool upper = (sub_blk >= 4); unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; + unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); + unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step; HVX_Vector v_off = v_scat_base; // reset to column 0 unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_0_group_hvx( - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + if (is_q4_1) { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + } else { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } } (void) *(volatile HVX_Vector *)(tile_base); } else if (weight_type == HTP_TYPE_MXFP4) { @@ -593,6 +686,8 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( // --- End x4x2 dequantizers --- +#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics + // requires external HMX lock static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, int n_row_tiles, int n_col_tiles, int n_dot_tiles) { diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 54cfadd9b0a..aadc77235ba 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -20,6 +20,7 @@ enum htp_data_type { HTP_TYPE_F32 = 0, HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, HTP_TYPE_Q8_0 = 8, HTP_TYPE_IQ4_NL = 20, HTP_TYPE_I32 = 26, @@ -28,6 +29,7 @@ enum htp_data_type { // types used internally for repack, dyn.quant, etc HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q4_1x4x2, HTP_TYPE_Q8_0x4x2, HTP_TYPE_MXFP4x4x2, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index f3a0866c7cd..7dd90ac7d7f 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -853,6 +853,11 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; + if (i == (n_ops-1)) { + // wake up the host before starting the last op + dspqueue_write_early_wakeup_noblock(queue, 0, 0); + } + profile_start(ctx->profiler, &prof); proc_op_req(octx, tens, i, &ops[i]); @@ -869,8 +874,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { } } - // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); - struct htp_opbatch_rsp rsp; rsp.id = req.id; rsp.status = HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 46fc5602dc9..7036c491bc4 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -40,6 +40,11 @@ struct htp_matmul_context { const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); + void (*vec_dot_4x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0); + // Precomputed values uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; @@ -155,6 +160,13 @@ static inline size_t q8x4x2_row_size(uint32_t ne) { return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); } +static inline size_t q8_1x4x2_row_size(uint32_t ne) { + // ensures perfect alignment of quants and full row + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); +} + static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -223,6 +235,62 @@ static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, u return r; } +static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = v0; + r.v[i*2+1] = v1; + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_V_lo_W(v0_1_p); + r.v[i*2+1] = Q6_V_hi_W(v0_1_p); + } + + return r; +} + static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -401,82 +469,96 @@ static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 return hvx_vec_rmpy_x8_partial(x, y, 512); } -static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, +static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size @@ -486,11 +568,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -500,77 +582,306 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r2_ms = Q6_V_vand_QV(bmask, r2_ms); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r3_ms = Q6_V_vand_QV(bmask, r3_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { assert(n % 32 == 0); @@ -581,11 +892,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -595,9 +906,9 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums // Row sums (sf) - 4 accumulators for 2×2 tile HVX_Vector r0_c0_sum = Q6_V_vzero(); @@ -610,13 +921,13 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * uint32_t i = 0; for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) + // Load src1 columns HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + // Load src0 rows + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -625,16 +936,38 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); // Compute combined scales HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); // Apply scales and accumulate HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); @@ -642,40 +975,72 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); @@ -686,10 +1051,15 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); } // Reduce and store results @@ -700,26 +1070,26 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); @@ -729,12 +1099,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo // Apply scale to acc and accumulate into the row sum (qf32). const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) + const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); @@ -751,7 +1121,433 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); @@ -804,10 +1600,109 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, // Row sum (qf32) HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -817,58 +1712,86 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); } + static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -1163,6 +2086,135 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vx2, + const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, @@ -1282,37 +2334,148 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); +} + +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + + // Zero-out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Apply scale to acc and accumulate into the row sum (f32). const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -1321,11 +2484,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1341,23 +2507,32 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const r0_d = Q6_V_vdelta_VV(r0_d, expand); r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1373,30 +2548,40 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const r0_d = Q6_V_vdelta_VV(r0_d, expand); r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - // Zero-out unused scales + // Zero-out unused values HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, +static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; @@ -1413,17 +2598,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (f32). + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -1433,13 +2620,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1447,9 +2640,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, vy_d = Q6_Vsf_equals_Vqf32(vy_d); // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); r0_d = Q6_V_vdelta_VV(r0_d, expand); @@ -1458,29 +2648,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_d = Q6_V_vdelta_VV(r1_d, expand); r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1488,9 +2695,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, vy_d = Q6_Vsf_equals_Vqf32(vy_d); // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); r0_d = Q6_V_vdelta_VV(r0_d, expand); @@ -1499,28 +2703,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_d = Q6_V_vdelta_VV(r1_d, expand); r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); // Zero-out unused values HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); } + static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -2138,7 +3360,6 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { @@ -2168,39 +3389,89 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; + if (mmctx->vec_dot_4x1 != NULL) { + const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); + + // Prefill spad with 4x src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 4); } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x4) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 4); + } + } + + // Process leftovers + uint32_t ir0 = src0_end_row_x4; + if (ir0 + 2 <= src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 2); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + ir0 += 2; } - } + if (ir0 < src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + ir0 += 1; + } + } else { + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + // Prefill spad with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + } } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); @@ -2432,6 +3703,94 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // *** dynamic quant +static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; + + // --- Sum calculation --- + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements + // Sum 8 elements: + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + // Copy to stack to extract sums and vmaxes + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + __fp16 * y_d_half = (__fp16 *) y_d; + y_d_half[0] = d0; + y_d_half[1] = (float) sums[0] * d0; + y_d_half[2] = d1; + y_d_half[3] = (float) sums[8] * d1; + y_d_half[4] = d2; + y_d_half[5] = (float) sums[16] * d2; + y_d_half[6] = d3; + y_d_half[7] = (float) sums[24] * d3; +} + static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); @@ -2656,6 +4015,77 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes + const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales/sums + + // Temp scales override input since we're working off of the aligned temp buffer in VTCM + uint8_t * restrict t_d = (uint8_t *) x; + + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + } + + // now copy the scales/sums into final location + hvx_copy_f16_ua(y_d, t_d, nb * 16); +} + +static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8_1x4x2_row_size(ne0); + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -2751,24 +4181,35 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_Q8_0: mmctx->type = "q8x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_IQ4_NL: mmctx->type = "iq4nlx4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; return 0; default: return -1; @@ -2894,8 +4335,13 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } @@ -2962,7 +4408,7 @@ int op_matmul(struct htp_ops_context * octx) { // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } @@ -3098,8 +4544,13 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); From 8c8f213daccd54f7a913034c68a885c11b851134 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 27 May 2026 14:22:33 -0700 Subject: [PATCH 174/289] ggml-webgpu: remove legacy constants (llama/23672) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f6d17a073be..1846886db4e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -94,14 +94,6 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) { #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 -// For operations which process a row in parallel, this seems like a reasonable -// default -#define WEBGPU_ROW_SPLIT_WG_SIZE 64 - -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to -// implementations so this can be removed, necessary only for get_rows right now -#define WEBGPU_MAX_WG_SIZE 288 - /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to @@ -631,7 +623,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, size_t size) { std::vector params = { (uint32_t) offset, (uint32_t) size, value }; std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; + size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); @@ -1366,7 +1358,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -3716,13 +3708,13 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[0].value = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; constants[1].key = "bytes_per_thread"; constants[1].value = ctx->capabilities.memset_bytes_per_thread; ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); From 7e843a80e1df28463a64a92ffd10410d5280ab3b Mon Sep 17 00:00:00 2001 From: ymcki <84055651+ymcki@users.noreply.github.com> Date: Thu, 28 May 2026 12:23:21 +0800 Subject: [PATCH 175/289] opencl: OP_GATED_DELTA_NET (llama/23312) * OP_GATED_DELTA_NET impl * add back lanes_per_column declaration * removed has_subgroup_arithmetic and has_subgroup_clustered_reduce * removed trailing spaces and fixes indentation. Hard coded subgroup size for Adreno and Intel. Return not supported when K>1 state snapshot * support for K>1 state snapshot * removed picky indent multiple of 4 fixes * removed return that won\'t be executed --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 345 ++++++++++++++++-- .../ggml-opencl/kernels/gated_delta_net.cl | 247 +++++++++++++ 3 files changed, 566 insertions(+), 27 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gated_delta_net.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index f75d089b574..446fb727996 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -164,6 +164,7 @@ set(GGML_OPENCL_KERNELS sqr sqrt ssm_conv + gated_delta_net sub sum_rows cumsum diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 42286435bc6..6d6c3e8973d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -412,6 +412,7 @@ struct ggml_backend_opencl_context { size_t max_workgroup_size; bool fp16_support; bool has_vector_subgroup_broadcast; + bool has_qcom_subgroup_shuffle = false; // cl_qcom_subgroup_shuffle bool disable_fusion; std::regex *opfilter = nullptr; // regex of ops to not claim @@ -634,6 +635,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; + // [size_idx][kda][tgpp] where size_idx: 0=S_V=16, 1=32, 2=64, 3=128; kda: 0 or 1. + // tgpp 0 = TG variant (COLS_PER_LANE_GROUP=1), tgpp 1 = prefill variant (COLS_PER_LANE_GROUP=4). + cl_kernel kernel_gated_delta_net_f32[4][2][2] = {}; + cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; @@ -837,16 +842,16 @@ static std::vector g_ggml_backend_opencl_devices; static std::vector> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { - std::ifstream ifs(path); - if (!ifs) { - return ""; - } - std::string text; - ifs.seekg(0, std::ios::end); - text.resize(ifs.tellg()); - ifs.seekg(0, std::ios::beg); - ifs.read(&text[0], text.size()); - return text; + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + std::string text; + ifs.seekg(0, std::ios::end); + text.resize(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + ifs.read(&text[0], text.size()); + return text; } static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { @@ -2463,12 +2468,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err)); if (backend_ctx->program_upscale) { - cl_int err_bilinear; - backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); - if (err_bilinear != CL_SUCCESS) { + cl_int err_bilinear; + backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); + if (err_bilinear != CL_SUCCESS) { GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear); backend_ctx->kernel_upscale_bilinear = nullptr; - } + } } else { backend_ctx->kernel_upscale_bilinear = nullptr; } @@ -2538,8 +2543,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } - // conv2d - { + // conv2d + { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "conv2d.cl.h" @@ -2597,6 +2602,86 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // gated_delta_net: one kernel per (S_V, KDA, tgpp) triple. + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gated_delta_net.cl.h" + }; + #else + const std::string kernel_src = read_file("gated_delta_net.cl"); + #endif + + const int gdn_sizes[4] = { 16, 32, 64, 128 }; + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); + exit(1); + } + + for (int si = 0; si < 4; si++) { + const int S_V = gdn_sizes[si]; + + // MUST match the dispatcher heuristic in ggml_cl_gated_delta_net exactly. + int lanes_per_column; + if (S_V >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min(S_V, sg_size); + } + + // Round LANES_PER_COLUMN down until it is: + // * power-of-two + // * divides both S_V and sg_size + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (S_V % lanes_per_column) != 0 || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; + } + + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT((S_V % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); + + const bool is_partial_reduce = (lanes_per_column != 1) && (lanes_per_column < sg_size); + int use_qcom_shuffle = 0; + if (is_partial_reduce) { + if (backend_ctx->has_qcom_subgroup_shuffle) { + use_qcom_shuffle = 1; + } + } + for (int kda = 0; kda < 2; kda++) { + for (int tgpp = 0; tgpp < 2; tgpp++) { + const int cpl = (tgpp == 0) ? 1 : 4; + const int spw = (tgpp == 0) ? 1 : 1; + + std::string opts = compile_opts; + opts += " -DS_V=" + std::to_string(S_V); + opts += " -DKDA=" + std::to_string(kda); + opts += " -DSUBGROUP_SIZE=" + std::to_string(sg_size); + opts += " -DLANES_PER_COLUMN=" + std::to_string(lanes_per_column); + opts += " -DCOLS_PER_LANE_GROUP=" + std::to_string(cpl); + opts += " -DUSE_QCOM_SUBGROUP_SHUFFLE=" + std::to_string(use_qcom_shuffle); + + // Since spw=1 is found to be optimal, SUBGROUPS_PER_WG > 1 code in + // the kernel is removed. If you want to experiment with spw > 1, + // Please remember to implement code to handle it. + opts += " -DSUBGROUPS_PER_WG=" + std::to_string(spw); + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), opts); + + CL_CHECK((backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp] = + clCreateKernel(prog, "kernel_gated_delta_net", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + } + } + } + GGML_LOG_CONT("."); + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2827,7 +2912,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q4_1_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl"); #endif @@ -2866,7 +2951,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_iq4_nl_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl"); #endif @@ -2905,7 +2990,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q8_0_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q8_0_f32.cl"); #endif @@ -2946,7 +3031,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q4_k_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl"); #endif @@ -3781,6 +3866,16 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // check support for qcom_subgroup_shuffle + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") != NULL) { + GGML_LOG_INFO("ggml_opencl: cl_khr_subgroups support: true\n"); + if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { + backend_ctx->has_qcom_subgroup_shuffle = true; + } + } + GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", + backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); + // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); @@ -4832,17 +4927,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_EXP: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_SOFTPLUS: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; default: return false; } @@ -4891,6 +4986,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_SSM_CONV: return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_GATED_DELTA_NET: + { + // Match the Vulkan backend: only F32 -> F32, S_v in {16, 32, 64, 128}. + if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) { + return false; + } + const int64_t S_v = op->src[2]->ne[0]; + return S_v == 16 || S_v == 32 || S_v == 64 || S_v == 128; + } case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -10555,7 +10659,7 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t size_t local_work_size[] = { lws0, 1, 1 }; size_t * local_work_size_ptr = local_work_size; - if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { + if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { local_work_size_ptr = nullptr; } @@ -17052,6 +17156,185 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) { + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src_q = dst->src[0]; + const ggml_tensor * src_k = dst->src[1]; + const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_g = dst->src[3]; + const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; + + GGML_ASSERT(src_q && src_q->extra); + GGML_ASSERT(src_k && src_k->extra); + GGML_ASSERT(src_v && src_v->extra); + GGML_ASSERT(src_g && src_g->extra); + GGML_ASSERT(src_beta && src_beta->extra); + GGML_ASSERT(src_state && src_state->extra); + + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) backend->context; + + const cl_uint S_v = (cl_uint) src_v->ne[0]; + const cl_uint H_v = (cl_uint) src_v->ne[1]; + const cl_uint n_tokens = (cl_uint) src_v->ne[2]; + const cl_uint n_seqs = (cl_uint) src_v->ne[3]; + const cl_uint K = (cl_uint) src_state->ne[1]; + + int si; + switch (S_v) { + case 16: si = 0; break; + case 32: si = 1; break; + case 64: si = 2; break; + case 128: si = 3; break; + default: + GGML_ASSERT(false && "ggml_cl_gated_delta_net: unsupported S_v"); + } + + const int kda = (src_g->ne[0] == (int64_t) S_v) ? 1 : 0; + + // TODO: Optimize when S_v!=128. Not necessary for now as Qwen3.5/6 are all S_v=128 + // token generation mode (tgpp=0): + // process 1 token at a time, so columns per lane (cpl) == 1 + // prompt processing mode (tgpp=1): + // cpl=4 to process 4 tokens for single-token. 4 is chosen for Adreno 750 as per + // work-item/thread has at most 128 registers. + // All Qwen3.5/6 models are S_v == 128, so LANES_PER_COLUMN == 8 + // such that ROWS_PER_LANE = 128/8 = 16 + // Variables in the kernel: + // k_reg, q_reg, g_exp are all 16 floats + // s_shard has cpl*ROWS_PER_LANE = 4*16 = 64 floats + // Total 112 registers used. + // subgroups_per_workgroup (spw) can be set to 1,2,4,8,16 for tg and 1,2,4 for pp + // for S_v=128. + // Empirically found that when spw=1, we get the best performance for both tg and pp + const int tgpp = (n_tokens == 1) ? 0 : 1; + const int cpl = (tgpp == 0) ? 1 : 4; + // spw needs adjustment when S_v != 128 + const int spw = (tgpp == 0) ? 1 : 1; + + cl_kernel kernel = backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp]; + GGML_ASSERT(kernel != nullptr); + + const cl_uint s_off = S_v * H_v * n_tokens * n_seqs; + + const cl_uint sq1 = (cl_uint)(src_q->nb[1] / sizeof(float)); + const cl_uint sq2 = (cl_uint)(src_q->nb[2] / sizeof(float)); + const cl_uint sq3 = (cl_uint)(src_q->nb[3] / sizeof(float)); + const cl_uint sv1 = (cl_uint)(src_v->nb[1] / sizeof(float)); + const cl_uint sv2 = (cl_uint)(src_v->nb[2] / sizeof(float)); + const cl_uint sv3 = (cl_uint)(src_v->nb[3] / sizeof(float)); + const cl_uint sb1 = (cl_uint)(src_beta->nb[1] / sizeof(float)); + const cl_uint sb2 = (cl_uint)(src_beta->nb[2] / sizeof(float)); + const cl_uint sb3 = (cl_uint)(src_beta->nb[3] / sizeof(float)); + + const cl_uint H_k = (cl_uint) src_q->ne[1]; + const cl_uint rq3 = (cl_uint)(src_v->ne[3] / src_q->ne[3]); + + const float scale = 1.0f / sqrtf((float) S_v); + + ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *) src_q->extra; + ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *) src_k->extra; + ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *) src_v->extra; + ggml_tensor_extra_cl * extra_g = (ggml_tensor_extra_cl *) src_g->extra; + ggml_tensor_extra_cl * extra_beta = (ggml_tensor_extra_cl *) src_beta->extra; + ggml_tensor_extra_cl * extra_state = (ggml_tensor_extra_cl *) src_state->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *) dst->extra; + + const cl_ulong off_q = extra_q->offset + src_q->view_offs; + const cl_ulong off_k = extra_k->offset + src_k->view_offs; + const cl_ulong off_v = extra_v->offset + src_v->view_offs; + const cl_ulong off_g = extra_g->offset + src_g->view_offs; + const cl_ulong off_beta = extra_beta->offset + src_beta->view_offs; + const cl_ulong off_state = extra_state->offset + src_state->view_offs; + const cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + int idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_q)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_k->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_v->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_g->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_g)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_beta->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_beta)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_state->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_state)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_tokens)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_seqs)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s_off)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &rq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &K)); + + // Subgroup size is 64 for Adreno and 32 for Intel + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); + exit(1); + } + + // For the subgroup-shuffle kernel, we can safely prefer 8 lanes/column for S_v>=128 + // For the subgroup-shuffle kernel: + // S_v >= 128 -> prefer 8 lanes/column (good occupancy & register pressure tradeoff) + // else -> min(S_v, subgroup_size) + int lanes_per_column; + if ((int)S_v >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min((int)S_v, sg_size); + } + + // Max workgroup size for Adreno 750 is 1024 + const int wg_size = sg_size * spw; + + // Ensure lanes_per_column is a power-of-two and divides both S_v and subgroup_size. + // (Required for lane-group shuffle-xor reduction correctness.) + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (((int)S_v % lanes_per_column) != 0) || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; + } + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT(((int)S_v % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); + + const int cols_per_wg = spw * (sg_size / lanes_per_column) * cpl; + GGML_ASSERT(cols_per_wg > 0); + GGML_ASSERT(((int)S_v % cols_per_wg) == 0); + + size_t global_work_size[3]; + size_t local_work_size[3]; + + global_work_size[0] = (size_t) H_v * (size_t) wg_size; + global_work_size[1] = (size_t) n_seqs; + global_work_size[2] = (size_t) S_v / (size_t) cols_per_wg; + + local_work_size[0] = (size_t) wg_size; + local_work_size[1] = 1; + local_work_size[2] = 1; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -17267,8 +17550,8 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_group_norm; break; - case GGML_OP_REPEAT: - if (!any_on_device) { + case GGML_OP_REPEAT: + if (!any_on_device) { return false; } func = ggml_cl_repeat; @@ -17297,6 +17580,14 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_ssm_conv; break; + case GGML_OP_GATED_DELTA_NET: + if (!any_on_device) { + return false; + } + // GDN has 6 source tensors, so it cannot use the standard + // (src0, src1, dst) func signature. Dispatch directly and return. + ggml_cl_gated_delta_net(backend, tensor); + return true; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl new file mode 100644 index 00000000000..d11192f5802 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl @@ -0,0 +1,247 @@ +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifndef S_V +#define S_V 128 +#endif +#ifndef KDA +#define KDA 0 +#endif +#ifndef SUBGROUP_SIZE +#define SUBGROUP_SIZE 64 +#endif +#ifndef LANES_PER_COLUMN +#define LANES_PER_COLUMN 8 +#endif +#ifndef COLS_PER_LANE_GROUP +#define COLS_PER_LANE_GROUP 1 +#endif +#ifndef SUBGROUPS_PER_WG +#define SUBGROUPS_PER_WG 1 +#endif +#ifndef USE_QCOM_SUBGROUP_SHUFFLE +#define USE_QCOM_SUBGROUP_SHUFFLE 0 +#endif + +#define WG_SIZE (SUBGROUP_SIZE * SUBGROUPS_PER_WG) +#define LANE_GROUPS_PER_SG (SUBGROUP_SIZE / LANES_PER_COLUMN) +#define COLS_PER_SG (LANE_GROUPS_PER_SG * COLS_PER_LANE_GROUP) +#define COLS_PER_WG (SUBGROUPS_PER_WG * COLS_PER_SG) +#define ROWS_PER_LANE (S_V / LANES_PER_COLUMN) + +#if USE_QCOM_SUBGROUP_SHUFFLE +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +#endif + +// XOR-based parallel sum +// This does a reduction across groups of LANES_PER_COLUMN +static inline float reduce_add_shmem(float partial, __local float * temp, uint lane) { +#if USE_QCOM_SUBGROUP_SHUFFLE + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + partial += qcom_sub_group_shuffle_xor(partial, s, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, partial); + } + return partial; +#else + temp[lane] = partial; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + float other = temp[lane ^ s]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + temp[lane] += other; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + } + const float result = temp[lane]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + return result; +#endif +} + +#define REDUCE_PARTIAL(partial, temp_ptr, lid) \ + ((LANES_PER_COLUMN == 1u) ? (partial) : reduce_add_shmem((partial), (temp_ptr), (lid))) + +// force compiler to optimize kernel for a specific fixed work-group size +__attribute__((reqd_work_group_size(WG_SIZE, 1, 1))) +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gated_delta_net( + global const char * q_buf, ulong off_q, + global const char * k_buf, ulong off_k, + global const char * v_buf, ulong off_v, + global const char * g_buf, ulong off_g, + global const char * beta_buf, ulong off_beta, + global const char * state_buf, ulong off_state, + global char * dst_buf, ulong off_dst, + uint H_v, + uint n_tokens, + uint n_seqs, + uint s_off, + uint sq1, uint sq2, uint sq3, + uint sv1, uint sv2, uint sv3, + uint sb1, uint sb2, uint sb3, + uint H_k, + uint rq3, + float scale, + uint K) { + + global const float * data_q = (global const float *)(q_buf + off_q); + global const float * data_k = (global const float *)(k_buf + off_k); + global const float * data_v = (global const float *)(v_buf + off_v); + global const float * data_g = (global const float *)(g_buf + off_g); + global const float * data_beta = (global const float *)(beta_buf + off_beta); + global const float * data_state = (global const float *)(state_buf + off_state); + global float * data_dst = (global float *)(dst_buf + off_dst); + + const uint head_id = get_group_id(0); + const uint seq_id = get_group_id(1); + const uint tid = (uint)get_local_id(0); + + const uint sg_id = get_sub_group_id(); // subgroup id + const uint sg_lid = get_sub_group_local_id(); // subgroup lane id + + const uint lane = sg_lid % LANES_PER_COLUMN; + const uint lane_group = sg_lid / LANES_PER_COLUMN; + const uint wg_col_base = get_group_id(2) * COLS_PER_WG; + const uint sg_col_base = wg_col_base + sg_id * COLS_PER_SG; + + const uint iq1 = head_id % H_k; // head index for Q and K + const uint iq3 = seq_id / rq3; // seq index for Q and K + + const uint state_size = S_V * S_V; + const uint state_base = (seq_id * K * H_v + head_id) * state_size; + const uint q_off_base = iq3 * sq3 + iq1 * sq1; + const uint v_off_base = seq_id * sv3 + head_id * sv1; + const uint gb_off_base = seq_id * sb3 + head_id * sb1; + const uint state_out_base = (seq_id * H_v + head_id) * state_size; + const uint state_size_per_snap = state_size * H_v * n_seqs; + + __local float reduce_temp[WG_SIZE]; + __local float * temp_ptr = reduce_temp + sg_id * SUBGROUP_SIZE; + + float s_shard[COLS_PER_LANE_GROUP][ROWS_PER_LANE]; + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[cg][r] = data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]; + } + } + + const int shift = (int)n_tokens - (int)K; + uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V; + + for (uint t = 0; t < n_tokens; t++) { + const uint q_off = q_off_base + t * sq2; + const uint k_off = q_off; + const uint v_off = v_off_base + t * sv2; + const uint gb_off = gb_off_base + t * sb2; + const float beta_val = data_beta[gb_off]; + + float k_reg[ROWS_PER_LANE]; + float q_reg[ROWS_PER_LANE]; +#if KDA + float g_exp[ROWS_PER_LANE]; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + g_exp[r] = exp(data_g[gb_off * S_V + i]); + } +#else + const float g_val = exp(data_g[gb_off]); + + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + } +#endif + + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + float v_val = data_v[v_off + col]; + + float kv_shard = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; + kv_shard += gs * k_reg[r]; +#else + kv_shard += s_shard[cg][r] * k_reg[r]; +#endif + } + +#if !KDA + kv_shard *= g_val; // Applied once instead of ROWS_PER_LANE times +#endif + + const float kv_col = REDUCE_PARTIAL(kv_shard, temp_ptr, sg_lid); + + const float delta_col = (v_val - kv_col) * beta_val; + + float attn_partial = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; +#else + float gs = g_val * s_shard[cg][r]; +#endif + s_shard[cg][r] = gs + k_reg[r] * delta_col; + attn_partial += s_shard[cg][r] * q_reg[r]; + } + const float attn_col = REDUCE_PARTIAL(attn_partial, temp_ptr, sg_lid); + + if (lane == 0) { + data_dst[attn_off + col] = attn_col * scale; + } + } + attn_off += S_V * H_v; + + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + const uint slot_base = s_off + (uint)target_slot * state_size_per_snap + state_out_base; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } + } + } + + if (K == 1u) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } +} From d284e1c3aa307e56cc43a81152c1cebea46e29cb Mon Sep 17 00:00:00 2001 From: ymcki <84055651+ymcki@users.noreply.github.com> Date: Thu, 28 May 2026 14:05:25 +0800 Subject: [PATCH 176/289] Hexagon: OP_GATED_DELTA_NET K>1 support (llama/23531) * K>1 state snapshot support * removed picky indent multiple of 4 fixes --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 +-- .../ggml-hexagon/htp/gated-delta-net-ops.c | 32 +++++++++++++++---- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5e8a4a740c1..3af7aff7028 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2537,6 +2537,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses const int64_t H = v->ne[1]; const int64_t n_tokens = v->ne[2]; const int64_t n_seqs = v->ne[3]; + const int64_t K = state->ne[1]; if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { return false; @@ -2549,10 +2550,10 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } - if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { + if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) { return false; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index 2e84badc9b7..c4d08bb21c4 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -586,6 +586,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -606,6 +607,10 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_sums[4] __attribute__((aligned(128))); + const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + const int64_t shift = (int64_t) n_tokens - (int64_t) K; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { const uint32_t iv1 = ir % H; const uint32_t iv3 = ir / H; @@ -615,8 +620,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = iv3 / rq3; const uint32_t ik3 = iv3 / rk3; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; memcpy(s_out, s_in, gctx->state_bytes); float * s_work = s_out; @@ -689,6 +694,16 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo } } + if (K > 1) { + const int64_t target_slot = (int64_t) t - shift; + if (target_slot >= 0 && target_slot < (int64_t) K) { + float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + if (curr_state_o != s_work) { + memcpy(curr_state_o, s_work, gctx->state_bytes); + } + } + } + attn_data += (uint64_t) S_v * H; } } @@ -709,6 +724,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t S_v = v->ne[0]; const uint32_t H = v->ne[1]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -736,6 +752,9 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; } + const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { const uint32_t iv1 = ir % H; const uint32_t iv3 = ir / H; @@ -745,8 +764,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = iv3 / rq3; const uint32_t ik3 = iv3 / rk3; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; float * s_work; if (spad) { @@ -901,6 +920,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) { const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { return HTP_STATUS_NO_SUPPORT; @@ -913,10 +933,10 @@ int op_gated_delta_net(struct htp_ops_context * octx) { (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return HTP_STATUS_NO_SUPPORT; } - if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { return HTP_STATUS_NO_SUPPORT; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return HTP_STATUS_NO_SUPPORT; } From 8e403258767fa1e0946006cee470064f0375f0ba Mon Sep 17 00:00:00 2001 From: Martin Klacer Date: Thu, 28 May 2026 08:04:21 +0100 Subject: [PATCH 177/289] ggml: fixed Arm SVE usage bug in vec.h, vec.cpp (llama/22841) * Updated vec.h/vec.cpp code to accumulate to F32 rather than F16 Change-Id: I0cb789347f2bf60ffaf9047319f727e788c825f8 Signed-off-by: Martin Klacer Co-authored-by: Milos Puzovic --- ggml/src/ggml-cpu/vec.cpp | 90 +++++++++------------- ggml/src/ggml-cpu/vec.h | 158 +++++++++++++++++--------------------- 2 files changed, 107 insertions(+), 141 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index d0e4001338a..67b6b05cac8 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -273,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; //get vector length - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - const int np= (n & ~(ggml_f16_step - 1)); - svfloat16_t sum1 = svdup_n_f16(0.0f); - svfloat16_t sum2 = svdup_n_f16(0.0f); - svfloat16_t sum3 = svdup_n_f16(0.0f); - svfloat16_t sum4 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1); - - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2); - - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3); - - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4); - - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 8 * ggml_f16_epr; + const int np = n - (n % ggml_f16_step); + const int np2 = n - (n % ggml_f16_epr); + + svfloat32_t sum1_lo = svdup_n_f32(0.0f); + svfloat32_t sum1_hi = svdup_n_f32(0.0f); + svfloat32_t sum2_lo = svdup_n_f32(0.0f); + svfloat32_t sum2_hi = svdup_n_f32(0.0f); + svfloat32_t sum3_lo = svdup_n_f32(0.0f); + svfloat32_t sum3_hi = svdup_n_f32(0.0f); + svfloat32_t sum4_lo = svdup_n_f32(0.0f); + svfloat32_t sum4_hi = svdup_n_f32(0.0f); - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6); - - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7); - - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8); + for (int i = 0; i < np; i += ggml_f16_step) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3)); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7)); } - const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8 - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry); + for (int i = np; i < np2; i += ggml_f16_epr) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0)); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2)); + const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2)); - sum1 = svmad_f16_x(pg, hx, hy, sum1); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry); } - GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4); + + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi); + sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo); + sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi); + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi); + + sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi); #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) int vl = __riscv_vsetvlmax_e32m2(); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index bcd68da9aa9..5de9cb5b7e0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -14,6 +14,35 @@ // floating point type used to accumulate sums typedef double ggml_float; +#if defined(__ARM_FEATURE_SVE) +inline static void ggml_sve_f16_fma_widened( + svfloat32_t * acc_lo, + svfloat32_t * acc_hi, + svfloat16_t x, + svfloat16_t y) { +#if defined(__ARM_FEATURE_SVE2) + *acc_lo = svmlalb_f32(*acc_lo, x, y); + *acc_hi = svmlalt_f32(*acc_hi, x, y); +#else + // Plain SVE fallback path if SVE2 instructions not available + svfloat16_t x_even = svtrn1_f16(x, x); + svfloat16_t x_odd = svtrn2_f16(x, x); + + svfloat16_t y_even = svtrn1_f16(y, y); + svfloat16_t y_odd = svtrn2_f16(y, y); + + svbool_t pg = svptrue_b32(); + + *acc_lo = svmla_f32_x(pg, *acc_lo, svcvt_f32_f16_x(pg, x_even), svcvt_f32_f16_x(pg, y_even)); + *acc_hi = svmla_f32_x(pg, *acc_hi, svcvt_f32_f16_x(pg, x_odd), svcvt_f32_f16_x(pg, y_odd)); +#endif +} + +inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) { + return (ggml_float) (svaddv_f32(svptrue_b32(), sum_lo) + svaddv_f32(svptrue_b32(), sum_hi)); +} +#endif + #define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 @@ -122,108 +151,61 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - int np = (n & ~(ggml_f16_step - 1)); - - svfloat16_t sum_00 = svdup_n_f16(0.0f); - svfloat16_t sum_01 = svdup_n_f16(0.0f); - svfloat16_t sum_02 = svdup_n_f16(0.0f); - svfloat16_t sum_03 = svdup_n_f16(0.0f); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 2 * ggml_f16_epr; + int np = n - (n % ggml_f16_step); + int np2 = n - (n % ggml_f16_epr); - svfloat16_t sum_10 = svdup_n_f16(0.0f); - svfloat16_t sum_11 = svdup_n_f16(0.0f); - svfloat16_t sum_12 = svdup_n_f16(0.0f); - svfloat16_t sum_13 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f); for (int i = 0; i < np; i += ggml_f16_step) { - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements - - ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 - ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); - - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements + const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); - ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0); - ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); - ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); - - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - - ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3); - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4); - ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4); - - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - - ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4); - - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5); - ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4); - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5); - - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - - ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5); - - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6); - ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6); - - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - - ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6); - - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7); - ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7); - - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - - ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7); - - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8); - ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8); + ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1); + ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + for (int i = np; i < np2; i += ggml_f16_epr) { + const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0); - sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry); - rx = GGML_F16x_VEC_LOAD(x[1] + k, 0); - sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); - svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); + const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00); - sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay); } - GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); - GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + + svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo); + svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi); + svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo); + svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi); + sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi); + sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi); np = n; #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) From 60e420ff6ac28ae5bc5af42b4a77bc98dca760e6 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 28 May 2026 10:55:42 +0200 Subject: [PATCH 178/289] cuda : fix KQ mask offset integer overflow in fattn MMA kernel (llama/23610) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Stanisław Szymczyk --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 4871b90df86..3c8b6eaaf24 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -472,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i); } } else if constexpr (oob_check) { #pragma unroll @@ -488,7 +488,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; - tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f); } } } else if constexpr (nbatch_fa < 2*warp_size) { @@ -505,7 +505,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = threadIdx.x % (warp_size/cols_per_warp); - ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i); } } else { #pragma unroll @@ -521,7 +521,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; - ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i); } } } From 5db94bac041884f804c49ae98a57a738f83b9c0c Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Thu, 28 May 2026 18:46:07 +0800 Subject: [PATCH 179/289] vulkan: Fix memory logger unsafe iterator access (llama/23667) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index fb07282ef76..b8ac4a9c26c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2095,9 +2095,9 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); std::string type = device ? "device" : "host"; auto it = allocations.find(buf->buffer); - total_device -= device ? it->second : 0; - total_host -= device ? 0 : it->second; if (it != allocations.end()) { + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); allocations.erase(it); } else { From 816c3029bc8cd046d4e2726ba2f1bbf99b8adc8f Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Thu, 28 May 2026 18:48:34 +0800 Subject: [PATCH 180/289] vulkan: fix wrong index variable in inner loop (llama/23665) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b8ac4a9c26c..238ee822397 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7233,7 +7233,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; for (uint64_t i0 = 0; i0 < ne0; i0++) { - slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 }); } } } From b896e91f18ec245f1415fe5d18a77e766197985e Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 28 May 2026 06:18:43 -0500 Subject: [PATCH 181/289] vulkan: fast path for walsh-hadamard transform (llama/23687) * vulkan: fast path for walsh-hadamard transform * disable for intel due to segfault --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 91 +++++++++++++++++++ ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp | 69 ++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 3 files changed, 161 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 238ee822397..c9f906d7930 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -860,6 +860,7 @@ struct vk_device_struct { vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_fwht_f32[4]; vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_cumsum_small_f32; vk_pipeline pipeline_cumsum_multipass1_f32; @@ -1150,6 +1151,13 @@ struct vk_op_push_constants { float param4; }; +struct vk_op_fwht_push_constants { + uint32_t n_rows; + uint32_t src_offset; + uint32_t dst_offset; + float scale; +}; + struct vk_op_count_experts_push_constants { uint32_t ne00; uint32_t ne01; @@ -2055,6 +2063,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + struct ggml_backend_vk_buffer_context { vk_device_ref device; vk_buffer dev_buffer; @@ -4982,6 +4999,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + // Intel Arc B390 was observed segfaulting with this shader. + if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) { + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + if (device->subgroup_size <= n) { + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size); + } + ++idx; + } + } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); @@ -8741,6 +8768,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } +static int ggml_vk_fwht_pipeline_idx(int64_t n) { + switch (n) { + case 64: return 0; + case 128: return 1; + case 256: return 2; + case 512: return 3; + default: return -1; + } +} + +static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) { + if (ctx->num_additional_fused_ops != 0) { + return false; + } + + if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) { + return false; + } + + const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]); + if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) { + return false; + } + + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(src1)) { + return false; + } + GGML_ASSERT(ggml_is_contiguous(dst)); + + return true; +} + +static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) { + const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]); + vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx]; + + const uint32_t rows_per_workgroup = 4; + const uint32_t n_rows = (uint32_t)ggml_nrows(src); + const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup); + const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true); + const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); + + vk_op_fwht_push_constants pc = { + n_rows, + 0, + 0, + 1.0f / std::sqrt((float)src->ne[0]), + }; + init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 }); +} + static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; @@ -8774,6 +8863,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c m_offset += cur_M_size; } + } else if (ggml_vk_can_use_fwht(ctx, src1, dst)) { + ggml_vk_fwht(ctx, subctx, src1, dst); } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && // detect 0213 permutation, and batch size of 1 src0->nb[0] <= src0->nb[2] && diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp new file mode 100644 index 00000000000..72059d4afc2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -0,0 +1,69 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + +layout(constant_id = 0) const uint WARP_SIZE = 32; +layout(constant_id = 1) const uint N = 128; + +layout(push_constant) uniform parameter +{ + uint n_rows; + uint src_offset; + uint dst_offset; + float scale; +}; + +layout(binding = 0, std430) readonly buffer A { float data_a[]; }; +layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; + +const uint EL_W = N / WARP_SIZE; + +void main() { + const uint lane = gl_SubgroupInvocationID; + for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; + row < n_rows; + row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row_offset = row * N; + + float reg[EL_W]; + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale; + } + + [[unroll]] + for (uint h = 1; h < WARP_SIZE; h <<= 1) { + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float val2 = subgroupShuffleXor(val, h); + reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + } + } + + [[unroll]] + for (uint h = WARP_SIZE; h < N; h <<= 1) { + const uint step = h / WARP_SIZE; + [[unroll]] + for (uint j = 0; j < EL_W; j += 2 * step) { + [[unroll]] + for (uint k = 0; k < step; ++k) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 24b9d25f733..fa9b938e4f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -934,6 +934,7 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("fwht_f32", "fwht.comp", {}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); From 1b241b879c4687d7ff4b3af1a14cb8e491a70d2d Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 28 May 2026 04:49:11 -0700 Subject: [PATCH 182/289] hexagon: minor refresh for HMX FA and MM (llama/23796) * hex-fa: clean up qf32/fp32 handling and stride handling * hex-fa: fix corner case fp NAN issues that were cause bad output from gemma4 on v79 * hex-fa: vectorize leftover handling * hex-fa: avoid HVX fallback during token gen HMX has more FP16 compute capacity * hmx-mm: remove dead code * hmx-mm: use fastdiv in x4x2 dequant * hmx-mm: sandwich dequant and scatter to improve perf * hmx-mm: fixed rebase conflicts * hmx-mm: further improve weight dequant by doing early type dispatch and precomputing fastdiv * hmx-mm: an even earlier dispatch for per-type dequant * hmx-mm: dequant linear types like q4_0 and q4_1 without the LUTs This is a bit faster than LUT. * hex-cmake: one more tweak for lto --------- Co-authored-by: Trivikram Reddy --- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 3 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 157 +++--- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 3 - ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 493 ++++++++++-------- 4 files changed, 370 insertions(+), 286 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index d7927261a85..ff3fc0804e3 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -58,15 +58,16 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE - hmx-queue.c hmx-flash-attn-ops.c hmx-matmul-ops.c + hmx-queue.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( hmx-flash-attn-ops.c hmx-matmul-ops.c + hmx-queue.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index d95df6ac9d5..1bd8c1407de 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -22,6 +22,16 @@ // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + // This is a bit of a hack because the compiler is strugling to properly inline // the default hvx_vec_f32_to_f16 with output into the local array. static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) @@ -54,8 +64,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum))); + HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)); + rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); hvx_vec_store_u(r, 4, rsum); } @@ -105,10 +115,10 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y, rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } - HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); - HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); - HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p))); - HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p))); + HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)); + HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)); + HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)); + HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)); HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } }; return hvx_vec_reduce_sum_f32x4(rsum0123); @@ -123,7 +133,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors const size_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector sums; // initialize at j = 0 + HVX_Vector sums = Q6_V_vzero(); const size_t stride_x_4 = stride_x * 4; for (uint32_t j = 0; j < VLEN_FP32; j += 4) { HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe); @@ -132,8 +142,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, x += stride_x_4; } - sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums); - return Q6_Vsf_equals_Vqf32(sums); + return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums); } // MAD: y (F32) += x (F16) * s (F16) @@ -268,11 +277,10 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * uint32_t i = 0; #pragma unroll(4) for (; i < nvec; ++i) { - vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs)); + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs); } if (nloe) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v)); + hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs)); } } @@ -438,25 +446,44 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Process in sub-blocks of 32 (VLEN_FP32) HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32]; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); - for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { + for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale); // 2. Softcap if (factx->logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); - scores = Q6_Vsf_equals_Vqf32(scores); + scores = HVX_OP_MUL_F32(scores, logit_cap); } // 3. Mask if (mask) { const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); - HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); - scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); - scores = Q6_Vsf_equals_Vqf32(scores); + + // Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79. + // Clamp -INFINITY to the max negative fp16 finite value (-65504.0f). + HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00); + HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF); + HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf); + m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16); + + #if __HVX_ARCH__ >= 79 + HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_vadd_VsfVsf(add_val, scores); + #else + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores)); + #endif + } + + // Mask out invalid lanes for leftover handling + uint32_t valid_lanes = current_block_size - ic; + if (valid_lanes < VLEN_FP32) { + HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane + scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY)); } sb_scores[iv] = scores; @@ -466,78 +493,55 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * { // 4. Online Softmax Update HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); - HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec)); + HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec); HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); M_vec = M_new_vec; hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); - for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) { HVX_Vector scores = sb_scores[iv]; - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); - HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec); + HVX_Vector P = hvx_vec_exp_f32(scores_shifted); - p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); + p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P); // 5. Accumulate V __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16]; hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0)); + float __attribute__((aligned(128))) P_arr[VLEN_FP32]; + hvx_vec_store_a(P_arr, 128, P); + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { - const uint32_t cur_ic = ic2 + j; - const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + const uint32_t cur_ic = ic2 + j; + if (cur_ic >= current_block_size) { + break; + } + + if (cur_ic + 1 == current_block_size) { + // Odd leftover, process single row + if (P_arr[j] != 0.0f) { + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV); + } + break; + } + + // Avoid NaN * 0.0 = NaN for uninitialized V cache rows. + // Check the f32 values to safely avoid strict aliasing violations. + if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) { + continue; + } + + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV); } } p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); - S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); - } - - if (ic < current_block_size) { - // Sync scalars for leftover/next block if needed - float M = hvx_vec_get_f32(M_vec); - float S = hvx_vec_get_f32(S_vec); - - // Leftover - for (; ic < current_block_size; ++ic) { - float s_val; - const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); - if (factx->logit_softcap != 0.0f) { - s_val = factx->logit_softcap * tanhf(s_val); - } - - if (mask) { - const float m_val = m_base[ic]; - s_val += slope * m_val; - } - - const float Mold = M; - __fp16 vs = 1.0f; - - if (s_val > M) { - M = s_val; - HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); - HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); - hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - - float ms = hvx_vec_get_f32(ms_vec); - S = S * ms + vs; - } else { - HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); - vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); - S += vs; - } - - const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; - - hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV); - } - - M_vec = hvx_vec_splat_f32(M); - S_vec = hvx_vec_splat_f32(S); + S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec); } // Issue DMA for next+1 block (if exists) @@ -599,8 +603,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const int i2 = iq2; const int i3 = iq3; - // dst is permuted - uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + // dst is permuted: [DV, n_heads, n_tokens, n_seq] + // head stride is nb[1], token stride is nb[2], batch stride is nb[3] + uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3]; if (dst->type == HTP_TYPE_F32) { hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); @@ -623,8 +628,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } #ifdef HTP_HAS_HMX - // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV - if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) { + // HMX path: head_dim multiple of 32, F16 KV + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) { int ret = hmx_flash_attn_ext(octx); if (ret == HTP_STATUS_OK) { return ret; diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index a496f6289ae..f132c08500d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -1248,9 +1248,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { if (DK % 32 != 0 || DV % 32 != 0) { return HTP_STATUS_NO_SUPPORT; } - if (neq1 < 32) { - return HTP_STATUS_NO_SUPPORT; - } // GQA factor const uint32_t n_kv_heads = k->ne[2]; diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index ab5fd73380b..083d125882d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -16,6 +16,7 @@ #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "worker-pool.h" #include "hvx-utils.h" @@ -187,45 +188,44 @@ static int hmx_compute_chunks(size_t vtcm_total, // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles // of the same 32 packed bytes. static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_32); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - // Shuffle before LUT - v_quants = Q6_Vb_vshuff_Vb(v_quants); - // Use standard vlut16 (not _nomatch) to avoid stale-register NaN. - // _nomatch retains the previous destination-register value for colliding - // indices, but the C intrinsic doesn't model the implicit read so the - // compiler may allocate a register containing garbage/NaN. - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); + + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); } // Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using -// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. +// full HVX vector width. // Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - // Load all 128 packed bytes (4 contiguous 32-byte groups) + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - // Shuffle before LUT - v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16] - HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); + + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); - // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -233,13 +233,12 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ - v_hi /* group2 already in [0:63] */ }; + HVX_Vector_x2 r = { v_lo, v_hi }; return r; } static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_32); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_dm = hvx_vmemu(scale_offset); @@ -248,9 +247,9 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32 HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); + + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); } @@ -258,16 +257,18 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32 static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); @@ -287,6 +288,45 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( return r; } +// LUT-based dequantizers for non-linear IQ4_NL format. +static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_Vector vscale = hvx_vmemu(scales_4); + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); @@ -374,122 +414,176 @@ static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * return r; } +typedef struct { + __fp16 *dst; + const uint8_t *src; + int n_cols; + int k_block; + size_t row_stride; + int weight_type; + int n_tot_tiles; + int n_tiles_per_task; + int n_tasks; + int n_k_tiles; + struct fastdiv_values n_k_tiles_div; +} x4x2_dequantize_state_t; + // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. // Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. // Output: vtcm_dst in tile-major FP16 layout. -static void dequantize_x4x2_weight_to_fp16_tiles_task( - __fp16 *restrict vtcm_dst, - const uint8_t *restrict vtcm_src, - int n_cols, int k_block, - size_t row_stride, int weight_type, - int start_tile, int end_tile) { - - const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL); - const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1); - const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; - - const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : - (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : - (weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) : - hvx_vmem(q4_0_to_fp16_lut); - // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. - // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 - // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) - - unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index - unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index - for (unsigned t = start_tile; t < end_tile; ) { - if (kt >= n_k_tiles) { kt = 0; ct++; } - - // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- - if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; - unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 - bool upper = (sub_blk_base >= 4); - unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; - unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); - unsigned scale_off = qrow_size + blk_idx * dblk_size - + sub_blk_base * scale_step; - - __fp16 *tile_bases[4]; - for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } - - HVX_Vector v_off = v_scat_base; - - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - - if (is_q4_1) { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); +#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \ +static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \ + const x4x2_dequantize_state_t *state, \ + int start_tile, int end_tile) { \ + \ + const int n_k_tiles = state->n_k_tiles; \ + const int qrow_size = (unsigned)state->k_block / 2; \ + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \ + const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \ + \ + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \ + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \ + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \ + \ + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \ + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \ + \ + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \ + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \ + \ + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \ + unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \ + unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk_base >= 4); \ + unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \ + \ + __fp16 *tile_bases[4]; \ + for (unsigned g = 0; g < 4; g++) { \ + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \ + } \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + \ + HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + \ + for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \ + t += 4; kt += 4; \ + continue; \ + } \ + \ + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \ + { \ + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \ + unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk >= 4); \ + unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \ + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \ + ? dequantize_x4x2_##helper_prefix##_group_hvx( \ + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \ + : Q6_V_vzero(); \ + \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + (void) *(volatile HVX_Vector *)(tile_base); \ + } \ + ++t; ++kt; \ + } \ + \ + if (start_tile < end_tile) { \ + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \ + } \ +} \ + \ +static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \ + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \ + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ + int start = task_id * state->n_tiles_per_task; \ + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ + dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \ + } \ +} - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); +DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) +DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4) +DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } else { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; +static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = state->k_block; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } - t += 4; kt += 4; - continue; - } + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales --- - if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + // Batch-4 fast path for MXFP4 + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4 + int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales + int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; __fp16 * tile_bases[4]; for (int g = 0; g < 4; g++) { - tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; } HVX_Vector v_off = v_scat_base; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t * r0 = vtcm_src + row0 * row_stride; - const uint8_t * r1 = vtcm_src + row1 * row_stride; + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; - // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); HVX_Vector_x4 dv0, dv1; dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); - if (row1 < n_cols) { + if (row1 < state->n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { @@ -510,58 +604,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( (void) *(volatile HVX_Vector *) (tile_bases[g]); } - t += 4; + t += 4; kt += 4; continue; } - // --- Single-tile fallback --- - __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - - if (is_q4) { - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; - unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; - bool upper = (sub_blk >= 4); - unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; - unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); - unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step; - - HVX_Vector v_off = v_scat_base; // reset to column 0 - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - if (is_q4_1) { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } else { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } - (void) *(volatile HVX_Vector *)(tile_base); - } else if (weight_type == HTP_TYPE_MXFP4) { + // Single-tile fallback + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { int blk_idx = (kt * 32) / QK_MXFP4x4x2; int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; bool upper = (sub_blk >= 4); @@ -573,15 +622,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t * r0 = vtcm_src + row0 * row_stride; - const uint8_t * r1 = vtcm_src + row1 * row_stride; + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; - // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); HVX_Vector v1; - if (row1 < n_cols) { + if (row1 < state->n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); } else { @@ -594,23 +642,59 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } (void) *(volatile HVX_Vector *) (tile_base); - } else { - // Q8_0 + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end); + } +} + +static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = state->k_block; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { int blk_idx = (kt * 32) / QK_Q8_0x4x2; int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); - HVX_Vector v_off = v_scat_base; // reset to column 0 + HVX_Vector v_off = v_scat_base; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t *r0 = vtcm_src + row0 * row_stride; - const uint8_t *r1 = vtcm_src + row1 * row_stride; + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); + HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); @@ -622,50 +706,31 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( ++t; ++kt; } - // Drain HVX scatter write buffer: a vmem load on the same HW thread retires - // all pending scatter entries to VTCM. Without this, the main thread's HMX - // reads may see stale data because atomic_fetch_sub (release) only orders - // regular stores, not the HVX scatter buffer. if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); } } -typedef struct { - __fp16 *dst; - const uint8_t *src; - int n_cols; - int k_block; - size_t row_stride; - int weight_type; - int n_tot_tiles; - int n_tiles_per_task; - int n_tasks; -} x4x2_dequantize_state_t; - -static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) { +static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) { x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { int start = task_id * state->n_tiles_per_task; int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - - dequantize_x4x2_weight_to_fp16_tiles_task( - state->dst, state->src, state->n_cols, state->k_block, - state->row_stride, state->weight_type, start, end); + dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end); } } static void dequantize_x4x2_weight_chunk_to_fp16_tiles( struct htp_context *ctx, __fp16 *vtcm_dst, const void *vtcm_src, int n_cols, int k_block, - size_t row_stride, int weight_type) { + size_t row_stride, int weight_type, + int n_k_tiles, struct fastdiv_values n_k_tiles_div, + worker_callback_t dequant_worker_fn) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; size_t n_tot_tiles = n_col_tiles * n_k_tiles; size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); @@ -680,8 +745,10 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.k_block = k_block; state.row_stride = row_stride; state.weight_type = weight_type; + state.n_k_tiles = n_k_tiles; + state.n_k_tiles_div = n_k_tiles_div; - worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads); } // --- End x4x2 dequantizers --- @@ -978,6 +1045,20 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * return -1; } + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + default: + return -1; + } + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + // --- Dynamic VTCM layout --- const size_t vec_dot_size = k * sizeof(__fp16); const size_t vtcm_budget = ctx->vtcm_size; @@ -1070,7 +1151,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * { // B0: wait for DMA, dequant weight chunk 0 dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); // A1: issue DMA for weight chunk 1 const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); @@ -1089,7 +1170,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); } } @@ -1131,7 +1212,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) if (i + 2 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); } } } From 04795e6272a74123486dfccfd0e62ecf816ba178 Mon Sep 17 00:00:00 2001 From: Jaden_Mach <88880593+jadenmach2@users.noreply.github.com> Date: Thu, 28 May 2026 08:50:25 -0400 Subject: [PATCH 183/289] CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware (llama/23227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: per-quant MMVQ/MMQ batch threshold on AMD MFMA hardware The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8) to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs substantially by quant family because the per-row GEMV cost is dominated by dequantisation, not the dot-product itself: K-quants pay a heavier super-block decode and so MMQ wins sooner; legacy and IQ quants have lean decode and stay ahead until the batch fully populates an MFMA tile. This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool, mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant thresholds on amd_mfma_available(cc): Q3_K, Q4_K, Q5_K : MMVQ <= 3 (MMQ wins from batch=4: +5% .. +76%) Q2_K, Q6_K : MMVQ <= 5 (MMQ wins from batch=6: +8% .. +35%) others : MMVQ <= 8 (legacy & IQ regress under MMQ; unchanged) Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold for A/B testing. Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct, llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps. Full table in PR description. Selected pp512 throughput (tok/s, ub=8): Q4_K_S: 559 -> 940 (+68%) Q5_K_S: 503 -> 884 (+76%) Q3_K_S: 629 -> 879 (+40%) Q2_K : 615 -> 809 (+32%) Q6_K : 582 -> 776 (+33%) Selected pp512 throughput (tok/s, ub=4): Q4_K_S: 444 -> 480 (+ 8%) Q4_0 : 682 -> 685 (+ 0%) (no regression - retains MMVQ) IQ4_XS: 706 -> 698 (- 1%) (no regression - retains MMVQ) * CUDA: address review — inline MMVQ batch table, drop env hatch & doc block * tune kernel selection logic for CDNA1 --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 ++ ggml/src/ggml-cuda/mmvq.cu | 47 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/mmvq.cuh | 2 ++ 3 files changed, 51 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 23d1c069248..dc3e8fd6265 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2570,6 +2570,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { @@ -2578,6 +2579,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 13b8b855282..873ff05a074 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -271,6 +271,53 @@ int get_mmvq_mmid_max_batch(ggml_type type, int cc) { return MMVQ_MAX_BATCH_SIZE; } +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + if (GGML_CUDA_CC_IS_CDNA1(cc)) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return ne11 <= 7; + case GGML_TYPE_Q5_1: + return ne11 <= 7; + case GGML_TYPE_Q8_0: + return ne11 <= 6; + case GGML_TYPE_Q2_K: + return ne11 <= 4; + case GGML_TYPE_Q3_K: + return ne11 <= 3; + case GGML_TYPE_Q4_K: + return ne11 <= 2; + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 4; + case GGML_TYPE_IQ1_S: + return ne11 <= 5; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + return ne11 <= 6; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + switch (type) { // tuned for CDNA2 + case GGML_TYPE_Q2_K: + return ne11 <= 5; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 5; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + return ne11 <= MMVQ_MAX_BATCH_SIZE; +} + // Device constexpr: returns the max batch size for the current arch+type at compile time. template static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 6bf0a8e8677..5605bf7a4e6 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -2,6 +2,8 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11); + // Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, // based on the quantization type and GPU architecture (compute capability). int get_mmvq_mmid_max_batch(ggml_type type, int cc); From 4e8af441e5f5ec8b91e193a598929cce489374ed Mon Sep 17 00:00:00 2001 From: redfox <59549776+yaohengxu@users.noreply.github.com> Date: Thu, 28 May 2026 20:51:14 +0800 Subject: [PATCH 184/289] =?UTF-8?q?mmvq=20Optim:=20add=20MMVQ=5FPARAMETERS?= =?UTF-8?q?=5FTURING(mmvq=5Fparameter=5Ftable=5Fid)=20for=20=E2=80=A6=20(#?= =?UTF-8?q?23729)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mmvq Optim: add MMVQ_PARAMETERS_TURING(mmvq_parameter_table_id) for SM75 TURING * avoid a mismatch for JIT compilation of Turing device code for Ampere or newer Co-authored-by: Johannes Gäßler --------- Co-authored-by: Copilot Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmvq.cu | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 873ff05a074..ecb6fdedadd 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -63,6 +63,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_TURING, MMVQ_PARAMETERS_GCN, MMVQ_PARAMETERS_RDNA2, MMVQ_PARAMETERS_RDNA3_0, @@ -78,6 +79,8 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING && __CUDA_ARCH__ < GGML_CUDA_CC_AMPERE + return MMVQ_PARAMETERS_TURING; #else return MMVQ_PARAMETERS_GENERIC; #endif @@ -96,6 +99,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { return MMVQ_PARAMETERS_GCN; } + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING && ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_AMPERE) { + return MMVQ_PARAMETERS_TURING; + } return MMVQ_PARAMETERS_GENERIC; } @@ -417,11 +423,38 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d } return 1; } + if (table_id == MMVQ_PARAMETERS_TURING) { + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return 2; + default: + return 4; + } + } + switch (ncols_dst) { + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } return 1; } static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { - if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN || table_id == MMVQ_PARAMETERS_TURING) { switch (ncols_dst) { case 1: return small_k ? nwarps : 1; From e1faa7cb4d7b2c4f185fbf3fef04fc616d871fec Mon Sep 17 00:00:00 2001 From: fl0rianr <226492742+fl0rianr@users.noreply.github.com> Date: Thu, 28 May 2026 15:01:14 +0200 Subject: [PATCH 185/289] ggml: auto apply iGPU flag CUDA/HIP if integrated device (llama/23007) --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index dc3e8fd6265..18aaa098398 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4994,8 +4994,14 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { - GGML_UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device)); + + return prop.integrated + ? GGML_BACKEND_DEVICE_TYPE_IGPU + : GGML_BACKEND_DEVICE_TYPE_GPU; } static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { From 94922ce12cd88a5449e9451c34332b667e6a1d14 Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 28 May 2026 11:05:42 -0700 Subject: [PATCH 186/289] opencl: move backend info printing into its own function (llama/23702) * opencl: move backend info print into its own function * opencl: move new log line * opencl: fix for non adreno path --- ggml/src/ggml-opencl/ggml-opencl.cpp | 155 +++++++++++++++------------ 1 file changed, 86 insertions(+), 69 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 6d6c3e8973d..751ec6116c0 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -379,6 +379,8 @@ struct ggml_backend_opencl_device_context { GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + std::regex *opfilter = nullptr; // regex of ops to not claim + std::string opfilter_str; // regex string for opfilter size_t global_mem_size = 0; }; @@ -415,8 +417,6 @@ struct ggml_backend_opencl_context { bool has_qcom_subgroup_shuffle = false; // cl_qcom_subgroup_shuffle bool disable_fusion; - std::regex *opfilter = nullptr; // regex of ops to not claim - bool adreno_has_large_buffer; bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; @@ -428,6 +428,8 @@ struct ggml_backend_opencl_context { size_t image2d_max_width; size_t image2d_max_height; + cl_device_svm_capabilities svm_caps; + cl_context context; cl_command_queue queue; @@ -3731,6 +3733,68 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r return found_devices; } +static void ggml_opencl_print_backend_info(ggml_backend_opencl_device_context * dev_ctx) { + GGML_ASSERT(dev_ctx); + GGML_ASSERT(dev_ctx->backend_ctx); + + auto * backend_ctx = dev_ctx->backend_ctx; + + GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", + backend_ctx->driver_version.c_str()); + GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", + backend_ctx->fp16_support ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", + backend_ctx->alignment); + GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", + backend_ctx->global_mem_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", + backend_ctx->max_alloc_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", + backend_ctx->image_max_buffer_size); + GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", + backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); + GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", + backend_ctx->max_workgroup_size); + GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", + backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); + + // Print out configurations +#ifdef GGML_OPENCL_SOA_Q + GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); +#endif // GGML_OPENCL_SOA_Q + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); + if (backend_ctx->adreno_xmem_gemm_enabled) { + GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM enabled (temporary weight prepack)\n"); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + if (backend_ctx->adreno_use_large_buffer) { + if (!backend_ctx->adreno_has_large_buffer) { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); + backend_ctx->adreno_use_large_buffer = false; + } else { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); + } + } + + if (dev_ctx->opfilter) { + // for information only, the actual regex object is created in ggml_opencl_is_device_supported + GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", dev_ctx->opfilter_str.c_str()); + } +} + // check if device should be accepted static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { GGML_ASSERT(dev); @@ -3799,6 +3863,13 @@ static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { } clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + + const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); + if (str_opfilter) { + dev_ctx->opfilter_str = str_opfilter; + dev_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); + } + return true; } @@ -3850,15 +3921,12 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { char *driver_version = (char *)alloca(driver_version_str_size + 1); clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); driver_version[driver_version_str_size] = '\0'; - GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); backend_ctx->driver_version = driver_version; backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); backend_ctx->has_vector_subgroup_broadcast = (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) || (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17); - GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); size_t ext_str_size; clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); @@ -3867,18 +3935,12 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated // check support for qcom_subgroup_shuffle - if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") != NULL) { - GGML_LOG_INFO("ggml_opencl: cl_khr_subgroups support: true\n"); - if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { - backend_ctx->has_qcom_subgroup_shuffle = true; - } + if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { + backend_ctx->has_qcom_subgroup_shuffle = true; } - GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", - backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; - GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; @@ -3887,35 +3949,15 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); GGML_ASSERT(base_align_in_bits % 8u == 0); backend_ctx->alignment = base_align_in_bits / 8u; - GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); backend_ctx->global_mem_size = dev_ctx->global_mem_size; - GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); - - clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); - GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); - - clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL); - GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size); - clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL); - clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL); - GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); - - clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); - GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); - - // Check SVM. - cl_device_svm_capabilities svm_caps; - CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); - GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", - svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &backend_ctx->svm_caps, 0)); if (opencl_c_version.major >= 3) { // Assume it is not available for 3.0, since it is optional in 3.0. @@ -3931,36 +3973,15 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { backend_ctx->non_uniform_workgroups = true; } - // Print out configurations -#ifdef GGML_OPENCL_SOA_Q - GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); -#endif // GGML_OPENCL_SOA_Q - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); -#endif // GGML_OPENCL_USE_ADRENO_KERNELS - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use Adreno xmem GEMM backend_ctx->adreno_xmem_gemm_enabled = getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; - if (getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr) { - GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM %s\n", - backend_ctx->adreno_xmem_gemm_enabled ? - "enabled (temporary weight prepack)" : "requested but unsupported by this driver"); - } -#endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // determine whether to use large buffer for Adreno backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; - if (backend_ctx->adreno_use_large_buffer) { - if (!backend_ctx->adreno_has_large_buffer) { - GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); - backend_ctx->adreno_use_large_buffer = false; - } else { - GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); - } - } cl_int err; @@ -4010,12 +4031,6 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; - const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); - if (str_opfilter) { - backend_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); - GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", str_opfilter); - } - dev_ctx->backend_ctx = backend_ctx.release(); return dev_ctx->backend_ctx; } @@ -4825,7 +4840,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; // reject ops that match the opfilter regex - if (backend_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *backend_ctx->opfilter)) { + if (dev_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *dev_ctx->opfilter)) { return false; } @@ -7823,6 +7838,8 @@ static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, co /* .context = */ backend_ctx, }; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + ggml_opencl_print_backend_info(dev_ctx); return backend; GGML_UNUSED(params); From 442be1789d750994b8afaad8533e16e46730606e Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 28 May 2026 14:05:54 -0700 Subject: [PATCH 187/289] hexagon: basic/generic op fusion support and RMS_NORM+MUL fusion (llama/23835) Updating infra to enable op fusion and using RMS_NORM+MUL as the use-case. --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 143 +++++++-------- ggml/src/ggml-hexagon/htp-opnode.h | 241 +++++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 194 +++++++++++++++++++- ggml/src/ggml-hexagon/op-desc.h | 153 ---------------- 6 files changed, 497 insertions(+), 236 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp-opnode.h delete mode 100644 ggml/src/ggml-hexagon/op-desc.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3af7aff7028..48ded82e83c 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -39,7 +39,7 @@ #include "ggml-hexagon.h" #include "ggml-impl.h" #include "ggml-quants.h" -#include "op-desc.h" +#include "htp-opnode.h" #include "htp-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -102,23 +102,23 @@ static const char * status_to_str(uint32_t status) { // ** debug helpers -static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) { +static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const htp_opnode & node, const uint32_t req_flags) { if (!opt_verbose) return; - op_desc desc(op); + htp_opformat fmt(node); GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; - op_desc desc(op); + htp_opformat fmt(htp_opformat(htp_opnode{const_cast(op), {}, HTP_OP_INVALID})); GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); + ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no"); } -static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, +static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node, uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; @@ -129,15 +129,16 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } - op_desc desc(op); + htp_opformat fmt(node); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions struct ggml_hexagon_opbatch; struct ggml_hexagon_opqueue; +struct htp_opnode; struct ggml_hexagon_session { std::string name; @@ -167,7 +168,7 @@ struct ggml_hexagon_session { void allocate(int dev_id) noexcept(false); void release() noexcept(true); - void enqueue_op(htp_op_code opcode, const ggml_tensor *op); + void enqueue_op(const htp_opnode & node); void flush(bool all = true); void flush_pending(bool all = false); @@ -1782,12 +1783,10 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; -// Backend session implementation - struct ggml_hexagon_opbatch { ggml_hexagon_session* sess; - std::vector ops; // pointers to original ops + std::vector ops; // htp_opnode of ops std::vector h_bufs; // htp buffer descriptors std::vector h_tens; // htp tensor descriptors @@ -1919,7 +1918,7 @@ struct ggml_hexagon_opbatch { return ti; } - bool fit_op(const struct ggml_tensor *t) const { + bool fit_op(const htp_opnode & node) const { if (n_ops >= n_ops_max ) return false; // check how much extras we will need @@ -1939,10 +1938,10 @@ struct ggml_hexagon_opbatch { } }; - for (unsigned int i=0; i < HTP_OP_MAX_INPUTS && t->src[i]; i++) { - fit_tensor(t->src[i]); + for (const auto * src : node.get_inputs()) { + fit_tensor(src); } - fit_tensor(t); + fit_tensor(node.dst()); if ((extra_bufs + n_bufs) > n_bufs_max) return false; if ((extra_tens + n_tens) > n_tens_max) return false; @@ -1952,29 +1951,30 @@ struct ggml_hexagon_opbatch { } // assumes that fit_op() was called first and returned true - void add_op(htp_op_code opcode, const struct ggml_tensor * t) { + void add_op(const htp_opnode & node) { // Add new op unsigned int n = n_ops++; GGML_ASSERT(n_ops <= n_ops_max); - ops[n] = t; + ops[n] = node; htp_op_desc &o = h_ops[n]; - memcpy(&o.params, &t->op_params, sizeof(t->op_params)); - o.opcode = opcode; + memcpy(&o.params, &node.node->op_params, sizeof(node.node->op_params)); + o.opcode = node.opcode; o.flags = 0; if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(sess->c_name(), t, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), node, o.flags); + auto inputs = node.get_inputs(); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { - o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; + o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; } - o.dst = add_tensor(t); + o.dst = add_tensor(node.dst()); } }; @@ -1983,7 +1983,7 @@ struct ggml_hexagon_opqueue { ggml_hexagon_shared_buffer *shm_buf; size_t shm_blk_size; - using opvec = std::vector; + using opvec = std::vector; std::queue done; // completed batch ids std::vector op_cache; // per batch op cache @@ -2182,11 +2182,11 @@ void ggml_hexagon_session::flush_batch() { } } -void ggml_hexagon_session::enqueue_op(htp_op_code opcode, const ggml_tensor *op) { - if (!op_batch->fit_op(op)) { +void ggml_hexagon_session::enqueue_op(const htp_opnode & node) { + if (!op_batch->fit_op(node)) { flush_batch(); } - op_batch->add_op(opcode, op); + op_batch->add_op(node); } // Flush HTP response queue i.e wait for all outstanding requests to complete @@ -3179,10 +3179,43 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); + std::vector nodes; + nodes.reserve(graph->n_nodes); + + // Fusion for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * n = graph->nodes[i]; - if (op_is_compute(n) && (opt_opstage & HTP_OPSTAGE_QUEUE)) { - sess->enqueue_op(op_remap_to_htp(n), n); + if (!op_is_compute(n)) { + continue; + } + + ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + + htp_opnode node = { + /*.node =*/ n, + /*.fused =*/ {}, + /*.opcode =*/ HTP_OP_INVALID + }; + + if (n->op == GGML_OP_RMS_NORM && next_node) { + if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + node.add_fused(next_node); + node.opcode = HTP_OP_RMS_NORM_MUL; + i++; // skip the fused MUL node + } + } + + if (node.opcode == HTP_OP_INVALID) { + node.opcode = op_remap_to_htp(n); + } + + nodes.push_back(std::move(node)); + } + + // Queue and execute + if (opt_opstage & HTP_OPSTAGE_QUEUE) { + for (const auto & node : nodes) { + sess->enqueue_op(node); } } @@ -3201,51 +3234,7 @@ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { sess->flush(); } -struct node_info { - ggml_tensor * node; - - std::vector fused; - - ggml_op op() const { - return node->op; - } - - const ggml_tensor * dst() const { - return fused.empty() ? node : fused.back(); - } - - const ggml_tensor * src0() const { - return node->src[0]; - } - - const ggml_tensor * src1() const { - return node->src[1]; - } - - bool is_empty() const { - return ggml_op_is_empty(node->op); - } - - void add_fused(ggml_tensor * t) { - fused.push_back(t); - } - - bool stackable() const { - switch (this->op()) { - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - return ggml_is_quantized(this->src0()->type); - default: - return false; - } - } - - bool same_input(const node_info& n) const { - return n.src1() == this->src1(); - } -}; - -static std::vector ggml_hexagon_graph_optimize_reorder(const std::vector & nodes) { +static std::vector ggml_hexagon_graph_optimize_reorder(const std::vector & nodes) { const int n = nodes.size(); std::vector res; @@ -3299,14 +3288,14 @@ static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgr enum ggml_op ops[MAX_FUSE]; - std::vector nodes; + std::vector nodes; nodes.reserve(gf->n_nodes); // fuse nodes: // we don't want to make reorders that break fusing, so we first pack all fusable tensors // and perform the reorder over the fused nodes. after the reorder is done, we unfuse for (int i = 0; i < n; i++) { - node_info node = { + htp_opnode node = { /*.node =*/gf->nodes[i], /*.fused =*/{}, }; diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h new file mode 100644 index 00000000000..14b232240b4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -0,0 +1,241 @@ +#ifndef HTP_OPNODE_H +#define HTP_OPNODE_H + +#define GGML_COMMON_IMPL_CPP +#include "ggml-backend-impl.h" +#include "ggml-common.h" + +#include +#include +#include +#include "htp-ops.h" + +struct htp_opnode { + ggml_tensor * node = nullptr; + + std::vector fused; + + htp_op_code opcode = HTP_OP_INVALID; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + const ggml_tensor * src0() const { + return node->src[0]; + } + + const ggml_tensor * src1() const { + return node->src[1]; + } + + bool is_empty() const { + return ggml_op_is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } + + bool stackable() const { + switch (this->op()) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return ggml_is_quantized(this->src0()->type); + default: + return false; + } + } + + bool same_input(const htp_opnode& n) const { + return n.src1() == this->src1(); + } + + std::vector get_inputs() const { + std::vector inputs; + std::vector outputs; + outputs.push_back(node); + for (const auto * f : fused) { + outputs.push_back(f); + } + + auto contains = [&](const std::vector & vec, const ggml_tensor * t) { + for (const auto * x : vec) { + if (x == t) return true; + } + return false; + }; + + auto add_input = [&](const ggml_tensor * t) { + if (t && !contains(outputs, t) && !contains(inputs, t)) { + inputs.push_back(t); + } + }; + + for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) { + add_input(node->src[i]); + } + for (const auto * f : fused) { + for (int i = 0; i < GGML_MAX_SRC && f->src[i]; i++) { + add_input(f->src[i]); + } + } + return inputs; + } + + std::string op_name() const { + if (fused.empty()) { + return ggml_op_desc(node); + } + std::string name = ggml_op_desc(node); + for (const auto * f : fused) { + name += "+"; + name += ggml_op_desc(f); + } + return name; + } +}; + +struct htp_opformat { + char strides[64 * GGML_MAX_SRC]; + char dims[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + } else { + return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + } + } + + void format_op_dims(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_dims(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_dims(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_dims(self, node.dst()); + p += sprintf(p, "%s", self); + } + + int format_tensor_strides(char * str, const struct ggml_tensor * t) { + const char * c = ggml_is_contiguous(t) ? "" : "!"; + + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + } else { + return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + } + } + + void format_op_strides(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_strides(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_strides(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_strides(self, node.dst()); + p += sprintf(p, "%s", self); + } + + void format_op_types(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", ggml_type_name(inputs[0]->type)); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", ggml_type_name(inputs[i]->type)); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", ggml_type_name(node.dst()->type)); + } + + const char * tensor_buff_name(const struct ggml_tensor * t) { + if (t->buffer) { + return ggml_backend_buffer_name(t->buffer); + } + return "NONE"; + } + + void format_op_buffs(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", tensor_buff_name(inputs[0])); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", tensor_buff_name(inputs[i])); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", tensor_buff_name(node.dst())); + } + + void format_op_names(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", inputs[0]->name); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", inputs[i]->name); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", node.dst()->name); + } + + void format(const htp_opnode & node) { + format_op_dims(dims, node); + format_op_strides(strides, node); + format_op_types(types, node); + format_op_buffs(buffs, node); + format_op_names(names, node); + } + + htp_opformat() {} + htp_opformat(const htp_opnode & node) { format(node); } +}; + +#endif // HTP_OPNODE_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index aadc77235ba..fa85bf4ca0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -58,6 +58,7 @@ enum htp_op_code { HTP_OP_MUL_MAT, HTP_OP_MUL_MAT_ID, HTP_OP_RMS_NORM, + HTP_OP_RMS_NORM_MUL, HTP_OP_UNARY_SILU, HTP_OP_UNARY_GELU, HTP_OP_UNARY_SIGMOID, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 7dd90ac7d7f..623008be4e2 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -537,6 +537,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_NORM: case HTP_OP_RMS_NORM: + case HTP_OP_RMS_NORM_MUL: case HTP_OP_SCALE: case HTP_OP_SQR: case HTP_OP_SQRT: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 7d0431d8ba8..770a6673211 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -23,21 +23,26 @@ struct htp_unary_context { // Precomputed values const uint8_t * data_src0; + const uint8_t * data_src1; // weight/scale tensor for RMS_NORM_MUL uint8_t * data_dst; size_t src0_data_row_size; // actual data bytes per row + size_t src1_data_row_size; size_t dst_data_row_size; // actual data bytes per row size_t src0_row_size_aligned; + size_t src1_row_size_aligned; size_t dst_row_size_aligned; size_t src0_spad_half_size; + size_t src1_spad_half_size; size_t dst_spad_half_size; uint32_t block; uint32_t src0_nrows; uint32_t src0_nrows_per_thread; uint32_t nc; + bool broadcast_weight; }; // Convert flat row index to DDR byte offset using the tensor's actual strides. @@ -158,6 +163,71 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_rms_norm_mul_f32(const uint8_t * restrict src, + const uint8_t * restrict weight, + uint8_t * restrict dst, + const int num_elems, + float epsilon) { + const HVX_Vector * restrict v_src = (const HVX_Vector *) src; + const HVX_Vector * restrict v_weight = (const HVX_Vector *) weight; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); + HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + + // Scale and multiply + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[i]); + v_dst[i] = Q6_Vsf_equals_Vqf32(result); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[nvec]); + HVX_Vector res_v = Q6_Vsf_equals_Vqf32(result); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, res_v); + } +} + static void hvx_fast_norm_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, @@ -269,6 +339,27 @@ static void rms_norm_f32(const float * restrict src, } } +static void rms_norm_mul_f32(const float * restrict src, + const float * restrict weight, + float * restrict dst, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + const size_t weight_row_size, + int32_t * op_params, + bool broadcast_weight) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + const uint8_t * restrict w_local = (const uint8_t *)weight + (broadcast_weight ? 0 : ir * weight_row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_rms_norm_mul_f32(src_local, w_local, dst_local, row_elems, epsilon); + } +} + static void norm_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -598,12 +689,15 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * t1 = HAP_perf_get_qtimer_count(); const uint8_t * restrict data_src = uctx->data_src0; + const uint8_t * restrict data_src1 = uctx->data_src1; uint8_t * restrict data_dst = uctx->data_dst; uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half_size = uctx->src0_spad_half_size; + size_t src1_spad_half_size = uctx->src1_spad_half_size; size_t dst_spad_half_size = uctx->dst_spad_half_size; // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride @@ -624,6 +718,12 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue * dma_queue = octx->ctx->dma[ith]; + // If weight is broadcasted, load it once per thread at the beginning of execution + if (htp_op == HTP_OP_RMS_NORM_MUL && uctx->broadcast_weight) { + dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data, data_src1), uctx->src1_row_size_aligned, 0, uctx->src1_data_row_size, 1); + dma_queue_flush(dma_queue); + } + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) { const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); @@ -636,6 +736,14 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue_push(dma_queue, dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off), src0_row_size_aligned, nb01, src0_data_row_size, block_size); + + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), + uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, block_size); + } + ir += block_size; } @@ -644,6 +752,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * float * dst_spad = (float *) dma_queue_pop(dma_queue).src; float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = NULL; + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + src1_spad = (float *) dma_queue_pop(dma_queue).dst; + } // Process block in VTCM switch (htp_op) { @@ -653,6 +765,12 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_RMS_NORM_MUL: + { + const float * w_ptr = uctx->broadcast_weight ? (const float *) src1_spad_data : src1_spad; + rms_norm_mul_f32(src0_spad, w_ptr, dst_spad, block_size, ne0, src0_row_size_aligned, uctx->src1_row_size_aligned, op_params, uctx->broadcast_weight); + } + break; case HTP_OP_SCALE: scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -700,9 +818,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * if (pref_ir < src0_end_row) { const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); - dma_queue_push(dma_queue, - dma_make_ptr(src0_spad, data_src + src0_pref_off), - src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad, data_src + src0_pref_off), + src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad, data_src1 + src1_pref_off), + uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, pref_block_size); + } } } ir += block_size; @@ -732,6 +857,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break; + case HTP_OP_RMS_NORM_MUL: + op_type = "rmsnorm-mul-f32"; + break; case HTP_OP_SCALE: op_type = "scale-f32"; break; @@ -777,12 +905,44 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN); + size_t src1_data_row_size = 0; + size_t src1_row_size_aligned = 0; + bool broadcast_weight = false; + const struct htp_tensor * src1 = NULL; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + src1 = octx->src[1]; + src1_data_row_size = src1->ne[0] * sizeof(float); + src1_row_size_aligned = hex_round_up(src1_data_row_size, VLEN); + broadcast_weight = (src1->ne[1] * src1->ne[2] * src1->ne[3] == 1); + } + // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size // Double buffering requires 2x size per buffer - size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); - size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + size_t spad_size_per_row = 0; + size_t vtcm_row_per_thread = 0; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + size_t available_vtcm = octx->ctx->vtcm_size; + size_t src1_spad_total = n_threads * src1_row_size_aligned; + if (available_vtcm > src1_spad_total) { + available_vtcm -= src1_spad_total; + } else { + available_vtcm = 0; + } + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = available_vtcm / (n_threads * spad_size_per_row); + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned + src1_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row); + } + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + } // Make sure the reserved vtcm size is sufficient if (vtcm_row_per_thread == 0) { @@ -797,8 +957,25 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + octx->src1_spad.size_per_thread = src1_row_size_aligned; + } else { + octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread * 2; + } + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } else { + octx->src1_spad.size = 0; + octx->src1_spad.size_per_thread = 0; + } + octx->src0_spad.data = octx->ctx->vtcm_base; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + if (octx->op == HTP_OP_RMS_NORM_MUL) { + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + } else { + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], @@ -811,19 +988,24 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .src0_nrows = src0_nrows, .data_src0 = (const uint8_t *)src0->data, + .data_src1 = (octx->op == HTP_OP_RMS_NORM_MUL) ? (const uint8_t *)src1->data : NULL, .data_dst = (uint8_t *)dst->data, .src0_data_row_size = src0_data_row_size, + .src1_data_row_size = src1_data_row_size, .dst_data_row_size = dst_data_row_size, .src0_row_size_aligned = src0_row_size_aligned, + .src1_row_size_aligned = src1_row_size_aligned, .dst_row_size_aligned = dst_row_size_aligned, .src0_spad_half_size = octx->src0_spad.size_per_thread / 2, + .src1_spad_half_size = (octx->op == HTP_OP_RMS_NORM_MUL) ? (octx->src1_spad.size_per_thread / (broadcast_weight ? 1 : 2)) : 0, .dst_spad_half_size = octx->dst_spad.size_per_thread / 2, .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned, .nc = src0->ne[0], + .broadcast_weight = broadcast_weight, }; worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads); diff --git a/ggml/src/ggml-hexagon/op-desc.h b/ggml/src/ggml-hexagon/op-desc.h deleted file mode 100644 index a1e8ddd8b97..00000000000 --- a/ggml/src/ggml-hexagon/op-desc.h +++ /dev/null @@ -1,153 +0,0 @@ -#ifndef OP_DESC_H -#define OP_DESC_H - -#define GGML_COMMON_IMPL_CPP -#include "ggml-backend-impl.h" -#include "ggml-common.h" - -#include -#include - -struct op_desc { - char strides[64 * GGML_MAX_SRC]; - char dims[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - int format_tensor_dims(char * str, const struct ggml_tensor * t) { - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); - } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); - } - } - - void format_op_dims(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_dims(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_dims(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_dims(self, t); - - p += sprintf(p, "%s", self); - } - - int format_tensor_strides(char * str, const struct ggml_tensor * t) { - const char * c = ggml_is_contiguous(t) ? "" : "!"; - - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); - } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); - } - } - - void format_op_strides(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_strides(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_strides(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_strides(self, t); - - p += sprintf(p, "%s", self); - } - - void format_op_types(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", ggml_type_name(t->src[0]->type)); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(t->src[i]->type)); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", ggml_type_name(t->type)); - } - - const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { - return ggml_backend_buffer_name(t->buffer); - } - return "NONE"; - } - - void format_op_buffs(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", tensor_buff_name(t->src[0])); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", tensor_buff_name(t->src[i])); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", tensor_buff_name(t)); - } - - void format_op_names(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", t->src[0]->name); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", t->src[i]->name); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", t->name); - } - - void format(const ggml_tensor * op) { - format_op_dims(dims, op); - format_op_strides(strides, op); - format_op_types(types, op); - format_op_buffs(buffs, op); - format_op_names(names, op); - } - - op_desc() {} - op_desc(const ggml_tensor * op) { format(op); } -}; - -#endif // OP_DESC_H From f1b687da28a6e28beb2a2e7ed2d74f554eb279be Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Fri, 29 May 2026 03:30:24 +0000 Subject: [PATCH 188/289] meta : Add missing `buffer` set in allreduce fallback !COMPUTE clear (llama/23480) Without this at least the vulkan backend will skip the `* 0` for !COMPUTE tensors, causing corrupt output. --- ggml/src/ggml-backend-meta.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index d0d64523b4a..48b2027fac3 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -2076,6 +2076,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, node_zero->src[0] = node; ggml_set_op_params_f32(node_zero, 0, 0.0f); node_zero->data = node->data; + node_zero->buffer = node->buffer; node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; step_cgraphs[j] = get_cgraph_aux(); From e90501e179632071cd7bba5cf5f05ec9991e64ff Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Fri, 29 May 2026 06:46:10 +0200 Subject: [PATCH 189/289] cuda : disables launch_fattn PDL enrollment due to compiler bug (llama/23825) --- ggml/src/ggml-cuda/fattn-common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index debcb6e5447..d650b5fbd0f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -1153,8 +1153,8 @@ void launch_fattn( GGML_ASSERT(block_dim.x % warp_size == 0); - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); - ggml_cuda_kernel_launch(fattn_kernel, launch_params, + // disabled PDL enrollment for now due to a compiler bug. + fattn_kernel<<>>( (const char *) Q->data, K_data, V_data, From cc65eb1816f780fd8478c58894f45b4c160e5ffc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:43:15 +0300 Subject: [PATCH 190/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index a4f87b2b9ae..6aed494381c 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -e705c5fed490514458bdd2eaddc43bd098fcce9b +5fbba2f28a17545214650298fd729563475004ca From 5828fba79f0c00f4cd7c7c205824b72664ac79d2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:44:28 +0300 Subject: [PATCH 191/289] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 6 +- examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-chat.cpp | 20 +++- examples/talk-llama/llama-chat.h | 1 + examples/talk-llama/llama-model.cpp | 3 + examples/talk-llama/llama-model.h | 2 +- examples/talk-llama/llama-vocab.cpp | 14 ++- examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/models/mistral3.cpp | 12 +- examples/talk-llama/models/models.h | 13 +++ examples/talk-llama/models/talkie.cpp | 149 ++++++++++++++++++++++++ 11 files changed, 213 insertions(+), 9 deletions(-) create mode 100644 examples/talk-llama/models/talkie.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index c9eead18aa3..e95ba6daac1 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -133,6 +133,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, + { LLM_ARCH_TALKIE, "talkie" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -767,8 +768,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super - {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, - {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU + {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 89cf16cc37c..7c1dcc4d6c2 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -137,6 +137,7 @@ enum llm_arch { LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, + LLM_ARCH_TALKIE, LLM_ARCH_UNKNOWN, }; diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index f10397747b0..6d822ec62d6 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -62,6 +62,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X }, { "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 }, + { "granite-4.1", LLM_CHAT_TEMPLATE_GRANITE_4_1 }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, @@ -194,7 +195,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { if (tmpl_contains("") || tmpl_contains("")) { - return LLM_CHAT_TEMPLATE_GRANITE_4_0; + if (tmpl_contains("g4_default_system_message")) { + return LLM_CHAT_TEMPLATE_GRANITE_4_0; + } + return LLM_CHAT_TEMPLATE_GRANITE_4_1; } return LLM_CHAT_TEMPLATE_GRANITE_3_X; } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { @@ -651,6 +655,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_1) { + // IBM Granite 4.1 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index ea6540c0be7..dc37f919a96 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -41,6 +41,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE_3_X, LLM_CHAT_TEMPLATE_GRANITE_4_0, + LLM_CHAT_TEMPLATE_GRANITE_4_1, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 0d21b2a53c5..0c3e03a61dc 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -44,6 +44,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_llama_embed(params); case LLM_ARCH_MAINCODER: return new llama_model_maincoder(params); + case LLM_ARCH_TALKIE: + return new llama_model_talkie(params); case LLM_ARCH_DECI: return new llama_model_deci(params); case LLM_ARCH_BAICHUAN: @@ -2353,6 +2355,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_TALKIE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 398a0aa725c..b797b8966ac 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -488,7 +488,7 @@ struct llama_layer { struct ggml_tensor * indexer_attn_k = nullptr; struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias - // gemma4 layer output scale + // gemma4 layer output scale, reused for talkie embedding skip scale struct ggml_tensor * out_scale = nullptr; struct llama_layer_posnet posnet; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a5cf148b268..473becade82 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -511,6 +511,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; byte_encode = false; break; + case LLAMA_VOCAB_PRE_TYPE_MINICPM5: + regex_exprs = { + // original regex from tokenizer.json (openbmb/MiniCPM5-1B) + "\\p{N}{1,3}", + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -2039,6 +2047,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if (tokenizer_pre == "default") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "minicpm5") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINICPM5; + ignore_merges = true; } else if ( tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || @@ -2196,7 +2207,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "gpt-4o" || tokenizer_pre == "llama4" || - tokenizer_pre == "kanana2") { + tokenizer_pre == "kanana2" || + tokenizer_pre == "talkie") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; } else if ( diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 8b040b912e2..8ab77594284 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -60,6 +60,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, }; struct LLM_KV; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 4e6ebef82cb..1ac5a95ccdc 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -177,9 +177,9 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -200,7 +200,11 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa LLM_FFN_SILU, true, hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 7e551eb965b..db228865d5d 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -186,6 +186,19 @@ struct llama_model_maincoder : public llama_model_base { }; +struct llama_model_talkie : public llama_model_base { + llama_model_talkie(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_deci : public llama_model_base { llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; diff --git a/examples/talk-llama/models/talkie.cpp b/examples/talk-llama/models/talkie.cpp new file mode 100644 index 00000000000..1258eeb19b6 --- /dev/null +++ b/examples/talk-llama/models/talkie.cpp @@ -0,0 +1,149 @@ +#include "models.h" + +void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_talkie::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // no k gain + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {1, n_head}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1}, 0); + } +} + +std::unique_ptr llama_model_talkie::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_talkie::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(inpL, "inp_norm", -1); + + ggml_tensor * embd_skip = inpL; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + ggml_tensor * inp_skip = embd_skip; + + cur = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + // reference applies qknorm after rope + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + Kcur = build_norm(Kcur, nullptr, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inp_skip = ggml_get_rows(ctx0, inp_skip, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + ggml_tensor * skip = ggml_mul(ctx0, inp_skip, model.layers[il].out_scale); + cb(skip, "embd_skip", il); + + cur = ggml_add(ctx0, cur, skip); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} From 92fc3f2a58bb6c518aef3bc8ddbe4c84e75a79b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:46:12 +0300 Subject: [PATCH 192/289] ggml : bump version to 0.13.1 (ggml/1523) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index f542f18b6d4..dc8899b46ef 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 13) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From f24588a272ae8e23280d9c220536437164e6ed28 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:46:42 +0300 Subject: [PATCH 193/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6aed494381c..538ef80bc7a 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -5fbba2f28a17545214650298fd729563475004ca +1e33fed33e87c43aa4c4078e2a9c239d4c1f1bd3 From f39cc7128295ff5c67bbedb73161bed549f96e96 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 15:44:07 +0300 Subject: [PATCH 194/289] common : re-implement `ffmpeg-transcode.cpp` + clarify ffmpeg usage (#3846) * examples : remove ffmpeg-transcode.cpp * examples : implement ffmpeg-transcode.cpp Assisted-by: llama.cpp:local pi * common : switch from WHISPER_FFMPEG -> WHISPER_COMMON_FFMPEG --- CMakeLists.txt | 3 +- README.md | 7 +- examples/CMakeLists.txt | 4 +- examples/common-whisper.cpp | 84 +++--- examples/ffmpeg-transcode.cpp | 553 +++++++++++++--------------------- tests/CMakeLists.txt | 2 +- 6 files changed, 271 insertions(+), 382 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2200673d0a3..35c8674725f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,7 +85,7 @@ option(WHISPER_CURL "whisper: use libcurl to download model from an URL" OFF) option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) if (CMAKE_SYSTEM_NAME MATCHES "Linux") - option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF) + option(WHISPER_COMMON_FFMPEG "whisper: examples link with ffmpeg libs in order to decode more audio formats" OFF) endif() option(WHISPER_COREML "whisper: enable Core ML framework" OFF) @@ -121,6 +121,7 @@ whisper_option_depr(WARNING WHISPER_RPC GGML_RPC) whisper_option_depr(WARNING WHISPER_SYCL GGML_SYCL) whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16) whisper_option_depr(WARNING WHISPER_CCACHE GGML_CCACHE) +whisper_option_depr(WARNING WHISPER_FFMPEG WHISPER_COMMON_FFMPEG) if (GGML_CUDA AND NOT MSVC) #GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets diff --git a/README.md b/README.md index 050a35be21c..d1680e99bfc 100644 --- a/README.md +++ b/README.md @@ -425,9 +425,10 @@ cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21" cmake --build build -j --config Release ``` -## FFmpeg support (Linux only) +## FFmpeg support (examples only) -If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration. +By default, the examples in this repo use the [miniaudio](https://github.com/mackron/miniaudio) library to decode audio files. +Some of the examples also can use FFmpeg for decoding and broader format support. To enable that, build with `WHISPER_COMMON_FFMPEG`. First, you need to install required libraries: @@ -442,7 +443,7 @@ sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-dev Then you can build the project as follows: ```bash -cmake -B build -D WHISPER_FFMPEG=yes +cmake -B build -D WHISPER_COMMON_FFMPEG=yes cmake --build build ``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b202ca00b77..0bb54cec489 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,7 +20,7 @@ set(TARGET common) unset(COMMON_EXTRA_LIBS) -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) # As of cmake 3.27, there is no official cmake support for FindFFmpeg. # Consequnelty we added a FindFFmpeg.cmake script the cmake subfolder: # whisper.cpp does not need the full ffmpeg libs, just AVFORMAT AVCODEC AVUTIL SWRESAMPLE @@ -39,7 +39,7 @@ if (WHISPER_FFMPEG) message(STATUS "Found avformat ${AVFORMAT_VERSION}") include_directories(${FFMPEG_INCLUDE_DIRS}) - add_compile_definitions(WHISPER_FFMPEG) + add_compile_definitions(WHISPER_COMMON_FFMPEG) list(APPEND COMMON_EXTRA_LIBS ${FFMPEG_LIBRARIES}) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index d29166b50d8..8cdd2320c17 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -34,8 +34,8 @@ #include #include -#ifdef WHISPER_FFMPEG -// as implemented in ffmpeg_trancode.cpp only embedded in common lib if whisper built with ffmpeg support +#ifdef WHISPER_COMMON_FFMPEG +// as implemented in ffmpeg-trancode.cpp only embedded in common lib if whisper built with ffmpeg support extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); #endif @@ -75,7 +75,7 @@ static bool read_audio_from_decoder(ma_decoder & decoder, std::vector & p return true; } -bool read_audio_data(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { +bool read_audio_data(const std::string & fname, std::vector & pcmf32, std::vector> & pcmf32s, bool stereo) { std::vector audio_data; // used for pipe input from stdin or ffmpeg decoding output ma_result result; @@ -96,53 +96,67 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); if (fname == "-") { - #ifdef _WIN32 - _setmode(_fileno(stdin), _O_BINARY); - #endif - - uint8_t buf[1024]; - while (true) - { - const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 0) { - break; - } - audio_data.insert(audio_data.end(), buf, buf + n); - } - - result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); +#ifdef _WIN32 + _setmode(_fileno(stdin), _O_BINARY); +#endif + + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + audio_data.insert(audio_data.end(), buf, buf + n); + } + + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); if (result != MA_SUCCESS) { - fprintf(stderr, "Error: failed to open audio data from stdin (%s)\n", ma_result_description(result)); - return false; - } + fprintf(stderr, "%s: failed to open audio data from stdin (%s)\n", __func__, ma_result_description(result)); + return false; + } decoder.initialized = true; - fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); - } - else { - result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); - if (result == MA_SUCCESS) { - decoder.initialized = true; + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); + } else { + fprintf(stderr, "%s: reading audio data from '%s' ...\n", __func__, fname.c_str()); + + // first try miniaudio. if it fails (or skipped) - try ffmpeg + { + const char * skip = getenv("WHISPER_COMMON_MINIAUDIO_SKIP"); + if (!skip || strlen(skip) == 0 || strcmp(skip, "0") == 0) { + fprintf(stderr, "%s: trying to decode with miniaudio\n", __func__); + + result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); + if (result == MA_SUCCESS) { + decoder.initialized = true; + } + } else { + fprintf(stderr, "%s: skipping miniaudio\n", __func__); + } } -#if defined(WHISPER_FFMPEG) + +#if defined(WHISPER_COMMON_FFMPEG) if (!decoder.initialized) { + fprintf(stderr, "%s: trying to decode with ffmpeg\n", __func__); + if (ffmpeg_decode_audio(fname, audio_data) != 0) { - fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); + fprintf(stderr, "%s: failed to ffmpeg decode\n", __func__); return false; } result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); if (result != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); + fprintf(stderr, "%s: failed to read audio data as wav (%s)\n", __func__, ma_result_description(result)); return false; } decoder.initialized = true; } -#else - if (!decoder.initialized) { - fprintf(stderr, "error: failed to read audio data from (%s)\n", fname.c_str()); - return false; - } #endif + + if (!decoder.initialized) { + fprintf(stderr, "%s: failed to read audio data\n", __func__); + return false; + } } return read_audio_from_decoder(decoder.decoder, pcmf32, pcmf32s, stereo); diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp index 1fae58a4ffa..dc57fe74596 100644 --- a/examples/ffmpeg-transcode.cpp +++ b/examples/ffmpeg-transcode.cpp @@ -1,368 +1,241 @@ -/* SPDX-License-Identifier: GPL-2.0 */ +#ifdef WHISPER_COMMON_FFMPEG -/* - * transcode.c - convert audio file to WAVE - * - * Copyright (C) 2019 Andrew Clayton - * Copyright (C) 2024 William Tambellini - */ +#include "whisper.h" -// Just for conveninent C++ API -#include #include - -// C -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include extern "C" { -#include -#include #include +#include #include } -typedef uint64_t u64; -typedef int64_t s64; -typedef uint32_t u32; -typedef int32_t s32; -typedef uint16_t u16; -typedef int16_t s16; -typedef uint8_t u8; -typedef int8_t s8; - -#define WAVE_SAMPLE_RATE 16000 -#define AVIO_CTX_BUF_SZ 4096 - -static const char* ffmpegLog = getenv("FFMPEG_LOG"); -// Todo: add __FILE__ __LINE__ -#define LOG(...) \ - do { if (ffmpegLog) fprintf(stderr, __VA_ARGS__); } while(0) // C99 - -/* - * WAVE file header based on definition from - * https://gist.github.com/Jon-Schneider/8b7c53d27a7a13346a643dac9c19d34f - * - * We must ensure this structure doesn't have any holes or - * padding so we can just map it straight to the WAVE data. - */ -struct wave_hdr { - /* RIFF Header: "RIFF" */ - char riff_header[4]; - /* size of audio data + sizeof(struct wave_hdr) - 8 */ - int wav_size; - /* "WAVE" */ - char wav_header[4]; - - /* Format Header */ - /* "fmt " (includes trailing space) */ - char fmt_header[4]; - /* Should be 16 for PCM */ - int fmt_chunk_size; - /* Should be 1 for PCM. 3 for IEEE Float */ - s16 audio_format; - s16 num_channels; - int sample_rate; - /* - * Number of bytes per second - * sample_rate * num_channels * bit_depth/8 - */ - int byte_rate; - /* num_channels * bytes per sample */ - s16 sample_alignment; - /* bits per sample */ - s16 bit_depth; - - /* Data Header */ - /* "data" */ - char data_header[4]; - /* - * size of audio - * number of samples * num_channels * bit_depth/8 - */ - int data_bytes; -} __attribute__((__packed__)); - -struct audio_buffer { - u8 *ptr; - int size; /* size left in the buffer */ -}; - -static void set_wave_hdr(wave_hdr& wh, size_t size) { - memcpy(&wh.riff_header, "RIFF", 4); - wh.wav_size = size + sizeof(struct wave_hdr) - 8; - memcpy(&wh.wav_header, "WAVE", 4); - memcpy(&wh.fmt_header, "fmt ", 4); - wh.fmt_chunk_size = 16; - wh.audio_format = 1; - wh.num_channels = 1; - wh.sample_rate = WAVE_SAMPLE_RATE; - wh.sample_alignment = 2; - wh.bit_depth = 16; - wh.byte_rate = wh.sample_rate * wh.sample_alignment; - memcpy(&wh.data_header, "data", 4); - wh.data_bytes = size; +// Write a minimal WAV header into the output buffer. +// Returns the number of bytes written (44 for a standard PCM WAV header). +static size_t wav_header_write(uint8_t * buf, int num_channels, int sample_rate, int bits_per_sample, uint32_t data_size) { + // RIFF header + memcpy(buf, "RIFF", 4); + uint32_t chunk_size = 36 + data_size; + memcpy(buf + 4, &chunk_size, 4); + memcpy(buf + 8, "WAVE", 4); + + // fmt subchunk + memcpy(buf + 12, "fmt ", 4); + uint32_t subchunk1_size = 16; + memcpy(buf + 16, &subchunk1_size, 4); + uint16_t audio_format = 1; // PCM + memcpy(buf + 20, &audio_format, 2); + memcpy(buf + 22, &num_channels, 2); + memcpy(buf + 24, &sample_rate, 4); + + int bytes_per_sample = (bits_per_sample / 8) * num_channels; + int byte_rate = sample_rate * bytes_per_sample; + memcpy(buf + 28, &byte_rate, 4); + memcpy(buf + 32, &bytes_per_sample, 2); + memcpy(buf + 34, &bits_per_sample, 2); + + // data subchunk + memcpy(buf + 36, "data", 4); + memcpy(buf + 40, &data_size, 4); + + return 44; } -static void write_wave_hdr(int fd, size_t size) { - struct wave_hdr wh; - set_wave_hdr(wh, size); - write(fd, &wh, sizeof(struct wave_hdr)); -} +bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data) { + { + const char * verbose = getenv("WHISPER_COMMON_FFMPEG_VERBOSE"); + if (verbose && strcmp(verbose, "2") == 0) { + av_log_set_level(AV_LOG_DEBUG); + } else if (verbose && strcmp(verbose, "1") == 0) { + av_log_set_level(AV_LOG_VERBOSE); + } else { + av_log_set_level(AV_LOG_WARNING); + } + } -static int map_file(int fd, u8 **ptr, size_t *size) -{ - struct stat sb; + AVFormatContext * fmt_ctx = nullptr; + if (avformat_open_input(&fmt_ctx, ifname.c_str(), nullptr, nullptr) != 0) { + fprintf(stderr, "error: failed to open input file '%s'\n", ifname.c_str()); + return true; + } - fstat(fd, &sb); - *size = sb.st_size; + if (avformat_find_stream_info(fmt_ctx, nullptr) < 0) { + fprintf(stderr, "error: failed to find stream information\n"); + avformat_close_input(&fmt_ctx); + return true; + } - *ptr = (u8*)mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); - if (*ptr == MAP_FAILED) { - perror("mmap"); - return -1; - } + // Find the first audio stream + int audio_stream_idx = -1; + for (unsigned int i = 0; i < fmt_ctx->nb_streams; i++) { + if (fmt_ctx->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { + audio_stream_idx = i; + break; + } + } - return 0; -} + if (audio_stream_idx == -1) { + fprintf(stderr, "error: failed to find an audio stream in '%s'\n", ifname.c_str()); + avformat_close_input(&fmt_ctx); + return true; + } -static int read_packet(void *opaque, u8 *buf, int buf_size) -{ - struct audio_buffer *audio_buf = (audio_buffer*)opaque; + AVStream * audio_stream = fmt_ctx->streams[audio_stream_idx]; - buf_size = FFMIN(buf_size, audio_buf->size); + // Open the decoder + const AVCodec * codec = avcodec_find_decoder(audio_stream->codecpar->codec_id); + if (!codec) { + fprintf(stderr, "error: failed to find decoder for codec id %d\n", audio_stream->codecpar->codec_id); + avformat_close_input(&fmt_ctx); + return true; + } - /* copy internal buffer data to buf */ - memcpy(buf, audio_buf->ptr, buf_size); - audio_buf->ptr += buf_size; - audio_buf->size -= buf_size; + AVCodecContext * codec_ctx = avcodec_alloc_context3(codec); + if (!codec_ctx) { + fprintf(stderr, "error: failed to allocate codec context\n"); + avformat_close_input(&fmt_ctx); + return true; + } - return buf_size; -} + if (avcodec_parameters_to_context(codec_ctx, audio_stream->codecpar) < 0) { + fprintf(stderr, "error: failed to copy codec parameters to context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static void convert_frame(struct SwrContext *swr, AVCodecContext *codec, - AVFrame *frame, s16 **data, int *size, bool flush) -{ - int nr_samples; - s64 delay; - u8 *buffer; - - delay = swr_get_delay(swr, codec->sample_rate); - nr_samples = av_rescale_rnd(delay + frame->nb_samples, - WAVE_SAMPLE_RATE, codec->sample_rate, - AV_ROUND_UP); - av_samples_alloc(&buffer, NULL, 1, nr_samples, AV_SAMPLE_FMT_S16, 0); - - /* - * !flush is used to check if we are flushing any remaining - * conversion buffers... - */ - nr_samples = swr_convert(swr, &buffer, nr_samples, - !flush ? (const u8 **)frame->data : NULL, - !flush ? frame->nb_samples : 0); - - *data = (s16*)realloc(*data, (*size + nr_samples) * sizeof(s16)); - memcpy(*data + *size, buffer, nr_samples * sizeof(s16)); - *size += nr_samples; - av_freep(&buffer); -} + if (avcodec_open2(codec_ctx, codec, nullptr) < 0) { + fprintf(stderr, "error: failed to open codec\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static bool is_audio_stream(const AVStream *stream) -{ - if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) - return true; + // Setup resampler: convert to 16-bit signed PCM, mono, 16000 Hz + const enum AVSampleFormat out_sample_fmt = AV_SAMPLE_FMT_S16; + const int out_sample_rate = WHISPER_SAMPLE_RATE; - return false; -} + AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; -// Return non zero on error, 0 on success -// audio_buffer: input memory -// data: decoded output audio data (wav file) -// size: size of output data -static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size) -{ - LOG("decode_audio: input size: %d\n", audio_buf->size); - AVFormatContext *fmt_ctx; - AVIOContext *avio_ctx; - AVStream *stream; - AVCodecContext *codec; - AVPacket *packet; - AVFrame *frame; - struct SwrContext *swr; - u8 *avio_ctx_buffer; - unsigned int i; - int stream_index = -1; - int err; - const size_t errbuffsize = 1024; - char errbuff[errbuffsize]; - - fmt_ctx = avformat_alloc_context(); - avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ); - LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ); - avio_ctx = avio_alloc_context(avio_ctx_buffer, AVIO_CTX_BUF_SZ, 0, audio_buf, &read_packet, NULL, NULL); - fmt_ctx->pb = avio_ctx; - - // open the input stream and read header - err = avformat_open_input(&fmt_ctx, NULL, NULL, NULL); - if (err) { - LOG("Could not read audio buffer: %d: %s\n", err, av_make_error_string(errbuff, errbuffsize, err)); - return err; - } - - err = avformat_find_stream_info(fmt_ctx, NULL); - if (err < 0) { - LOG("Could not retrieve stream info from audio buffer: %d\n", err); - return err; - } - - for (i = 0; i < fmt_ctx->nb_streams; i++) { - if (is_audio_stream(fmt_ctx->streams[i])) { - stream_index = i; - break; - } - } - - if (stream_index == -1) { - LOG("Could not retrieve audio stream from buffer\n"); - return -1; - } - - stream = fmt_ctx->streams[stream_index]; - codec = avcodec_alloc_context3( - avcodec_find_decoder(stream->codecpar->codec_id)); - avcodec_parameters_to_context(codec, stream->codecpar); - err = avcodec_open2(codec, avcodec_find_decoder(codec->codec_id), - NULL); - if (err) { - LOG("Failed to open decoder for stream #%d in audio buffer\n", stream_index); - return err; - } - - /* prepare resampler */ - swr = swr_alloc(); - -#if LIBAVCODEC_VERSION_MAJOR > 60 - AVChannelLayout in_ch_layout = codec->ch_layout; - AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; - - /* Set the source audio layout as-is */ - av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - - /* Convert it into 16khz Mono */ - av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#else - av_opt_set_int(swr, "in_channel_count", codec->channels, 0); - av_opt_set_int(swr, "out_channel_count", 1, 0); - av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); - av_opt_set_int(swr, "out_channel_layout", AV_CH_LAYOUT_MONO, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#endif - - swr_init(swr); - if (!swr_is_initialized(swr)) { - LOG("Resampler has not been properly initialized\n"); - return -1; - } - - packet=av_packet_alloc(); - if (!packet) { - LOG("Error allocating the packet\n"); - return -1; - } - frame = av_frame_alloc(); - if (!frame) { - LOG("Error allocating the frame\n"); - return -1; - } - - /* iterate through frames */ - *data = NULL; - *size = 0; - while (av_read_frame(fmt_ctx, packet) >= 0) { - avcodec_send_packet(codec, packet); - - err = avcodec_receive_frame(codec, frame); - if (err == AVERROR(EAGAIN)) - continue; - - convert_frame(swr, codec, frame, data, size, false); - } - /* Flush any remaining conversion buffers... */ - convert_frame(swr, codec, frame, data, size, true); - - av_packet_free(&packet); - av_frame_free(&frame); - swr_free(&swr); - //avio_context_free(); // todo? - avcodec_free_context(&codec); - avformat_close_input(&fmt_ctx); - avformat_free_context(fmt_ctx); - - if (avio_ctx) { - av_freep(&avio_ctx->buffer); - av_freep(&avio_ctx); - } - - return 0; -} + SwrContext * swr_ctx = nullptr; + if (swr_alloc_set_opts2(&swr_ctx, &out_ch_layout, out_sample_fmt, out_sample_rate, + &codec_ctx->ch_layout, codec_ctx->sample_fmt, codec_ctx->sample_rate, + 0, nullptr) < 0) { + fprintf(stderr, "error: failed to allocate swr context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -// in mem decoding/conversion/resampling: -// ifname: input file path -// owav_data: in mem wav file. Can be forwarded as it to whisper/drwav -// return 0 on success -int ffmpeg_decode_audio(const std::string &ifname, std::vector& owav_data) { - LOG("ffmpeg_decode_audio: %s\n", ifname.c_str()); - int ifd = open(ifname.c_str(), O_RDONLY); - if (ifd == -1) { - fprintf(stderr, "Couldn't open input file %s\n", ifname.c_str()); - return -1; + if (swr_init(swr_ctx) < 0) { + fprintf(stderr, "error: failed to initialize swr context\n"); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; } - u8 *ibuf = NULL; - size_t ibuf_size; - int err = map_file(ifd, &ibuf, &ibuf_size); - if (err) { - LOG("Couldn't map input file %s\n", ifname.c_str()); - return err; + + // Decode and resample + AVPacket * packet = av_packet_alloc(); + AVFrame * frame = av_frame_alloc(); + + // Buffer to collect resampled output + std::vector pcm_data; + + // Max output samples per swr_convert call + const int max_out_samples = 16 * 1024; + std::vector out_buffer(max_out_samples); + + while (av_read_frame(fmt_ctx, packet) >= 0) { + if (packet->stream_index != audio_stream_idx) { + av_packet_unref(packet); + continue; + } + + int ret = avcodec_send_packet(codec_ctx, packet); + av_packet_unref(packet); + + if (ret < 0) { + continue; + } + + while (ret >= 0) { + ret = avcodec_receive_frame(codec_ctx, frame); + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { + break; + } + if (ret < 0) { + break; + } + + // Resample + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } + } } - LOG("Mapped input file: %s size: %d\n", ibuf, (int) ibuf_size); - struct audio_buffer inaudio_buf; - inaudio_buf.ptr = ibuf; - inaudio_buf.size = ibuf_size; - - s16 *odata=NULL; - int osize=0; - - err = decode_audio(&inaudio_buf, &odata, &osize); - LOG("decode_audio returned %d \n", err); - if (err != 0) { - LOG("decode_audio failed\n"); - return err; + + // Flush the decoder + avcodec_send_packet(codec_ctx, nullptr); + while (avcodec_receive_frame(codec_ctx, frame) >= 0) { + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } } - LOG("decode_audio output size: %d\n", osize); - - wave_hdr wh; - const size_t outdatasize = osize * sizeof(s16); - set_wave_hdr(wh, outdatasize); - owav_data.resize(sizeof(wave_hdr) + outdatasize); - // header: - memcpy(owav_data.data(), &wh, sizeof(wave_hdr)); - // the data: - memcpy(owav_data.data() + sizeof(wave_hdr), odata, osize* sizeof(s16)); - - return 0; + + // Flush the resampler + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + int flush_samples = swr_convert(swr_ctx, out_data, max_out_samples, nullptr, 0); + if (flush_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + flush_samples); + } + + // Build WAV output + uint32_t data_size = pcm_data.size() * sizeof(int16_t); + wav_data.resize(44 + data_size); + + wav_header_write(wav_data.data(), 1, out_sample_rate, 16, data_size); + memcpy(wav_data.data() + 44, pcm_data.data(), data_size); + + // Cleanup + av_frame_free(&frame); + av_packet_free(&packet); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + + return false; // success } + +#endif // WHISPER_COMMON_FFMPEG diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..0593b748d36 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -78,7 +78,7 @@ add_test(NAME ${TEST_TARGET} -f ${PROJECT_SOURCE_DIR}/samples/jfk.wav) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "large") -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) set(TEST_TARGET test-whisper-cli-tiny-mp3) # Check with reviewers: any way to check the output transcription via ctest (diff, ...)? add_test(NAME ${TEST_TARGET} From 6c343e7a4ed01a77be70cc4be2f5001cc72521e3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 15:48:05 +0300 Subject: [PATCH 195/289] common : pass sample rate to `ffmpeg_decode_audio()` --- examples/common-whisper.cpp | 2 +- examples/ffmpeg-transcode.cpp | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index 8cdd2320c17..c84e6843adc 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -36,7 +36,7 @@ #ifdef WHISPER_COMMON_FFMPEG // as implemented in ffmpeg-trancode.cpp only embedded in common lib if whisper built with ffmpeg support -extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); +extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data, int out_sample_rate = WHISPER_SAMPLE_RATE); #endif // extract f32 PCM frames from an initialized decoder, downmix to mono and keep the stereo split diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp index dc57fe74596..7657af69823 100644 --- a/examples/ffmpeg-transcode.cpp +++ b/examples/ffmpeg-transcode.cpp @@ -1,7 +1,5 @@ #ifdef WHISPER_COMMON_FFMPEG -#include "whisper.h" - #include #include #include @@ -44,7 +42,7 @@ static size_t wav_header_write(uint8_t * buf, int num_channels, int sample_rate, return 44; } -bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data) { +bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data, int out_sample_rate) { { const char * verbose = getenv("WHISPER_COMMON_FFMPEG_VERBOSE"); if (verbose && strcmp(verbose, "2") == 0) { @@ -116,7 +114,6 @@ bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_ // Setup resampler: convert to 16-bit signed PCM, mono, 16000 Hz const enum AVSampleFormat out_sample_fmt = AV_SAMPLE_FMT_S16; - const int out_sample_rate = WHISPER_SAMPLE_RATE; AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; From 2e045a967b802564844fa17cf19792c8cf1f04ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 15:45:44 +0300 Subject: [PATCH 196/289] ci : remove obsolete self-hosted label --- .github/workflows/build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e855ef7cf87..773122a0f0a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1490,7 +1490,7 @@ jobs: LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt ggml-ci-x64-nvidia-cuda: - runs-on: [self-hosted, Linux, mnt-root, NVIDIA] + runs-on: [self-hosted, Linux, NVIDIA] steps: - name: Clone @@ -1504,7 +1504,7 @@ jobs: GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm: - runs-on: [self-hosted, Linux, mnt-root, NVIDIA] + runs-on: [self-hosted, Linux, NVIDIA] steps: - name: Clone @@ -1518,7 +1518,7 @@ jobs: GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm2: - runs-on: [self-hosted, Linux, mnt-root, NVIDIA, COOPMAT2] + runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] steps: - name: Clone From 099af1c67d26172e2607a57d945e6c4a19a57a6f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 16:04:12 +0300 Subject: [PATCH 197/289] pi : add config [no ci] --- .gitignore | 3 +++ .pi/gg/SYSTEM.md | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 .pi/gg/SYSTEM.md diff --git a/.gitignore b/.gitignore index 6eb8ff45915..7a98228af3c 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,6 @@ cmake-build-debug/ local.properties .log .exe + +# AGENTS +.pi/SYSTEM.md diff --git a/.pi/gg/SYSTEM.md b/.pi/gg/SYSTEM.md new file mode 100644 index 00000000000..1ae0e40674e --- /dev/null +++ b/.pi/gg/SYSTEM.md @@ -0,0 +1,27 @@ +You are a coding agent. Here are some very important rules that you must follow: + +General: +- Be very precise and concise when writing code, comments, explanations, etc. +- PR and commit titles format: ` : `. Lookup recents for examples +- Don't try to build or run the code unless you are explicitly asked to do so +- Use the `gh` CLI tool when querying PRs, issues, or other GitHub resources + +Coding: +- When in doubt, always refer to the CONTRIBUTING.md file of the project +- When referencing issues or PRs in comments, use the format: + - C/C++ code: `// ref: <url>` + - Other (CMake, etc.): `# ref: <url>` + +Pull requests (PRs): +- New branch names are prefixed with "gg/" +- Before opening a pull request, ask the user to confirm the description +- When creating a pull request, look for the repository's PR template and follow it +- For the AI usage disclosure section, write "YES. llama.cpp + pi + [MODEL]" +- Ask the user to tell you what model was used and write it in place of [MODEL] +- Always create the pull requests in draft mode + +Commits: +- On every commit that you make, include a "Assisted-by: llama.cpp:local pi" tag +- Do not explicitly set the git author in commits - rely on the default git config +- Always use `--no-gpg-sign` when committing +- Never `git push` without explicit confirmation from the user From fe69461618ffc50ba8afa65c25cc6c6e34d4537f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Sun, 31 May 2026 16:06:32 +0300 Subject: [PATCH 198/289] ci : fix self-hosted paths to mnt --- .github/workflows/build.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 773122a0f0a..878c5833eaa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1501,7 +1501,7 @@ jobs: id: ggml-ci run: | nvidia-smi - GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm: runs-on: [self-hosted, Linux, NVIDIA] @@ -1515,7 +1515,7 @@ jobs: id: ggml-ci run: | vulkaninfo --summary - GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm2: runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] @@ -1529,7 +1529,7 @@ jobs: id: ggml-ci run: | vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp #ggml-ci-x64-cpu-amx: # runs-on: [self-hosted, Linux, X64, CPU, AMX] @@ -1542,7 +1542,7 @@ jobs: # - name: Test # id: ggml-ci # run: | - # bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + # bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp ggml-ci-mac-metal: runs-on: [self-hosted, macOS, ARM64] From 0dff27498f704b9eab8527f03c769efb7e7f051c Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Mon, 1 Jun 2026 07:20:19 +0200 Subject: [PATCH 199/289] ci : fix path to whisper.h in examples.yml [no ci] (#3842) This commit updates the include path to whisper.h and also ensures that this is only built on pushes to master. --- .github/workflows/build.yml | 5 +++-- .github/workflows/examples.yml | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 878c5833eaa..b7badd51041 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,8 +29,9 @@ on: pull_request: types: [opened, synchronize, reopened] paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml workflow_dispatch: inputs: create_release: diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 1c9ade5a300..df3aa832c2e 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -1,13 +1,15 @@ name: Examples Tests on: push: + branches: + - master paths: - examples/addon.node/** - - whisper.h + - include/whisper.h pull_request: paths: - examples/addon.node/** - - whisper.h + - include/whisper.h jobs: addon_node-ubuntu-22: From 23ee03506a91ac3d3f0071b40e66a430eebdfa1d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 1 Jun 2026 14:56:20 +0300 Subject: [PATCH 200/289] release : v1.8.6 --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 35c8674725f..4df278c3ad8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.5) +project("whisper.cpp" VERSION 1.8.6) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index d1680e99bfc..fe7fa74153a 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.8.6](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.6) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index caf12b6dd2d..1f2f34672ae 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.5", + "version": "1.8.6", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From ef24de1e5814c4fb14cc396aa6aed623032073ab Mon Sep 17 00:00:00 2001 From: Patrice Levesque <github-wayne@ptaff.ca> Date: Tue, 2 Jun 2026 03:22:16 -0400 Subject: [PATCH 201/289] cmake : do not assume /usr/lib library installation. (#3693) Current `pkgconfig` configuration file installation path and its contents assume libraries are installed under `/usr/lib` and this is not always the case, for instance `/usr/lib64` is quite possible under Gentoo Linux. Thus use the `CMAKE_INSTALL_LIBDIR` variable instead of a hardcoded `lib`. --- CMakeLists.txt | 2 +- cmake/whisper.pc.in | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4df278c3ad8..3932cf2845e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -209,7 +209,7 @@ configure_file(cmake/whisper.pc.in @ONLY) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" - DESTINATION lib/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) # # programs, examples and tests diff --git a/cmake/whisper.pc.in b/cmake/whisper.pc.in index 00ec7912014..200179d5d11 100644 --- a/cmake/whisper.pc.in +++ b/cmake/whisper.pc.in @@ -1,6 +1,6 @@ prefix=@CMAKE_INSTALL_PREFIX@ exec_prefix=${prefix} -libdir=${exec_prefix}/lib +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ includedir=${prefix}/include Name: whisper From e5d44125788a69cca621c85c4d022e83162ac113 Mon Sep 17 00:00:00 2001 From: Noah Lyons <n.lyons53@gmail.com> Date: Tue, 2 Jun 2026 07:10:27 -0400 Subject: [PATCH 202/289] server : merge split utf-8 token text in verbose json (#3850) --- examples/cli/cli.cpp | 33 --------------------------------- examples/common-whisper.cpp | 28 ++++++++++++++++++++++++++++ examples/common-whisper.h | 3 +++ examples/server/server.cpp | 23 +++++++++++++++++++++-- tests/CMakeLists.txt | 8 ++++++++ tests/test-common-utf8.cpp | 34 ++++++++++++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 35 deletions(-) create mode 100644 tests/test-common-utf8.cpp diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 55cd71b4e55..7ca563dc250 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -31,39 +31,6 @@ static void replace_all(std::string & s, const std::string & search, const std:: } } -// Returns the number of trailing continuation bytes still needed for `s` to end -// on a complete UTF-8 codepoint. Returns 0 if the tail of `s` is already a -// complete codepoint (or if the tail looks malformed and we should stop merging). -// Used to merge whisper tokens whose bytes split a multi-byte UTF-8 character -// (e.g. CJK), so the JSON output stays valid UTF-8. See https://github.com/ggml-org/whisper.cpp/issues/1798. -static int utf8_trailing_bytes_needed(const std::string & s) { - const int n = (int) s.size(); - int i = n - 1; - // walk back past continuation bytes (10xxxxxx) - while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { - --i; - } - if (i < 0) { - // all continuation bytes, or empty — nothing we can do - return 0; - } - const unsigned char c = (unsigned char) s[i]; - int expected; - if ((c & 0x80) == 0x00) { - expected = 1; // ASCII - } else if ((c & 0xE0) == 0xC0) { - expected = 2; - } else if ((c & 0xF0) == 0xE0) { - expected = 3; - } else if ((c & 0xF8) == 0xF0) { - expected = 4; - } else { - return 0; // malformed lead, give up - } - const int have = n - i; - return have >= expected ? 0 : (expected - have); -} - // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index c84e6843adc..b12481c013f 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -198,6 +198,34 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100))); } +int utf8_trailing_bytes_needed(const std::string & s) { + const int n = (int) s.size(); + int i = n - 1; + while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { + --i; + } + if (i < 0) { + return 0; + } + + const unsigned char c = (unsigned char) s[i]; + int expected; + if ((c & 0x80) == 0x00) { + expected = 1; + } else if ((c & 0xE0) == 0xC0) { + expected = 2; + } else if ((c & 0xF0) == 0xE0) { + expected = 3; + } else if ((c & 0xF8) == 0xF0) { + expected = 4; + } else { + return 0; + } + + const int have = n - i; + return have >= expected ? 0 : (expected - have); +} + bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id) { std::ofstream speak_file(path.c_str()); if (speak_file.fail()) { diff --git a/examples/common-whisper.h b/examples/common-whisper.h index 8714c381046..aec430d3635 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -28,5 +28,8 @@ std::string to_timestamp(int64_t t, bool comma = false); // given a timestamp get the sample int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate); +// Returns the number of trailing bytes still needed for s to end on a complete UTF-8 codepoint. +int utf8_trailing_bytes_needed(const std::string & s); + // write text to file, and call system("command voice_id file") bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index aae74c3d840..b87ef27375f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1107,10 +1107,29 @@ int main(int argc, char ** argv) { } segment["tokens"].push_back(token.id); - json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; + std::string word_text = whisper_full_get_token_text(ctx, i, j); + int64_t word_t1 = token.t1; + + while (j + 1 < n_tokens && utf8_trailing_bytes_needed(word_text) > 0) { + const whisper_token_data next_token = whisper_full_get_token_data(ctx, i, j + 1); + // Keep verbose_json tokens free of EOT ids, matching the pre-merge server behavior. + if (next_token.id >= whisper_token_eot(ctx)) { + break; + } + + ++j; + segment["tokens"].push_back(next_token.id); + word_text += whisper_full_get_token_text(ctx, i, j); + if (next_token.t1 > -1) { + word_t1 = next_token.t1; + } + total_logprob += next_token.plog; + } + + json word = json{{"word", word_text}}; if (!params.no_timestamps && params.token_timestamps) { word["start"] = token.t0 * 0.01; - word["end"] = token.t1 * 0.01; + word["end"] = word_t1 * 0.01; word["t_dtw"] = token.t_dtw; } word["probability"] = token.p; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0593b748d36..646f45f2ab7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -88,6 +88,14 @@ if (WHISPER_COMMON_FFMPEG) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3") endif() +# UTF-8 helper unit test +set(UTF8_TEST test-common-utf8) +add_executable(${UTF8_TEST} ${UTF8_TEST}.cpp) +target_include_directories(${UTF8_TEST} PRIVATE ../examples) +target_link_libraries(${UTF8_TEST} PRIVATE common) +add_test(NAME ${UTF8_TEST} COMMAND ${UTF8_TEST}) +set_tests_properties(${UTF8_TEST} PROPERTIES LABELS "unit") + # VAD test tests VAD in isolation set(VAD_TEST test-vad) add_executable(${VAD_TEST} ${VAD_TEST}.cpp) diff --git a/tests/test-common-utf8.cpp b/tests/test-common-utf8.cpp new file mode 100644 index 00000000000..91c73a7428d --- /dev/null +++ b/tests/test-common-utf8.cpp @@ -0,0 +1,34 @@ +#include "common-whisper.h" + +#include <cstdlib> +#include <cstdio> +#include <string> + +static void expect_needed(const std::string & input, int expected) { + const int actual = utf8_trailing_bytes_needed(input); + if (actual != expected) { + fprintf(stderr, "expected %d trailing UTF-8 bytes, got %d\n", expected, actual); + std::abort(); + } +} + +int main() { + expect_needed("", 0); + expect_needed("plain ascii", 0); + + const std::string cjk = "\xE4\xBD\xA0"; // U+4F60 + expect_needed(cjk.substr(0, 1), 2); + expect_needed(cjk.substr(0, 2), 1); + expect_needed(cjk, 0); + + const std::string emoji = "\xF0\x9F\x98\x80"; // U+1F600 + expect_needed(emoji.substr(0, 1), 3); + expect_needed(emoji.substr(0, 2), 2); + expect_needed(emoji.substr(0, 3), 1); + expect_needed(emoji, 0); + + expect_needed("\x80\x80", 0); + expect_needed("\xFF", 0); + + return 0; +} From 610e664ba7cfe3af46125ed1b5a1184fccb51bcd Mon Sep 17 00:00:00 2001 From: danscMax <153344025+danscMax@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:25:29 +0300 Subject: [PATCH 203/289] whisper : catch C++ exceptions in whisper_init_with_params_no_state (#3831) whisper_model_load() can throw instead of returning false: std::runtime_error from this file (failed ggml context / no compatible buffer type), or vk::SystemError / vk::OutOfDeviceMemoryError from the ggml-vulkan backend during device/buffer allocation. whisper_init_* are extern "C", so a C++ exception unwinding across that boundary aborts non-C++ callers (Rust via whisper-rs, Go via cgo) -- on Windows STATUS_STACK_BUFFER_OVERRUN (0xC0000409) -- even though the function already returns NULL on failure. Wrap whisper_model_load() in try/catch and route any throw into the existing NULL-return path. Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com> --- src/whisper.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 0fe29a4541e..5ffc70af00e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -3720,7 +3720,21 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ whisper_context * ctx = new whisper_context; ctx->params = params; - if (!whisper_model_load(loader, *ctx)) { + // A C++ exception escaping this extern "C" function aborts non-C++ callers + // (Rust via whisper-rs, Go via cgo, ...). whisper_model_load can throw + // (std::runtime_error here; vk::SystemError from the Vulkan backend during + // device/buffer allocation), so funnel any throw into the existing + // NULL-return failure path instead of letting it cross the C ABI. + bool model_loaded = false; + try { + model_loaded = whisper_model_load(loader, *ctx); + } catch (const std::exception & e) { + WHISPER_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + WHISPER_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { loader->close(loader->context); WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); delete ctx; From 02d5316af5f9cef149ec20eebcba99cd6395b6b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Thu, 4 Jun 2026 09:35:58 +0300 Subject: [PATCH 204/289] ci : refactor + optimize (#3847) * ci : add ccache clear action * ci : split self-hosted GPU jobs into build-self-hosted.yml Extract self-hosted runner jobs from build.yml into a dedicated build-self-hosted.yml following the llama.cpp pattern: - gpu-cuda (NVIDIA Linux) - gpu-vulkan-nvidia-cm (NVIDIA Linux) - gpu-vulkan-nvidia-cm2 (NVIDIA Linux + COOPMAT2) - gpu-metal (macOS ARM64) - gpu-vulkan (macOS ARM64) GitHub-hosted CPU jobs remain in build.yml. Assisted-by: llama.cpp:local pi * ci : split release jobs into release.yml Extract release-related jobs from build.yml into a dedicated release.yml following the llama.cpp pattern: - determine-tag - windows (Win32/x64, SDL2) - windows-blas (Win32/x64, OpenBLAS) - windows-cublas (x64, CUDA 11.8/12.4) - ios-xcode-build - bindings-java (depends on windows) - release (artifact aggregation + GitHub release) CoreML job stays in build.yml with its own local tag calculation. Assisted-by: llama.cpp:local pi * ci : remove bindings-java job from release.yml Assisted-by: llama.cpp:local pi * cont : add manual trigger for build.yml * cont : remove obsolete ifs * ci : extract sanitizer job to bild-sanitize.yml * ci : extract linux jobs into build-linux.yml * ci : extract macos jobs to build-macos.yml * ci : extract gcc jobs to build-gcc.yml * ci : extract clang jobs to build-clang.yml * ci : extract sycl jobs to build-sycl.yml * ci : extract windows jobs to build-windows.yml * ci : extract emscripten job to build-wasm.yml * ci : extract android jobs into build-android.yml * ci : extract quantize job to quantize.yml * ci : extract coreml job into coreml.yml * ci : extract vad job to vad.yml * ci : extract cpu jobs to build-cpu.yml * ci : make naming of yml files consistent * ci : add --fail to curl download and propagate This commit adds the --fail option to the model download scripts so that if the model download returns a server error this is picked up. This is then detected in run.sh and a error message is displayed and the script stops and returns an error. The motivation for this is that currently it is possible for the model download to fail but this script proceeds and instead of a model file the contents will be an html page probably with the error. This will then cause the model to not be able to load due to a missing magic number. I'm not sure we can do much about the downloading failing, perhaps a retry but at least this will give a clearer error message. Refs: https://github.com/danbev/whisper.cpp/actions/runs/26866349389/job/79230794512 * ci : enable command traces to see download command in use * ci : add retry functionality to download model script This commit adds curl retry options to the model download script. The motivation is that currently when CI jobs run huggingface rate limit the requests and return: ```console curl: (22) The requested URL returned error: 429 ``` This is an attempt to work around this and if it does not work then we can an authorization token. * ci : extract freebsd job to build-freebsd.yml This job has been commented out as it has been flaky in the past. I'll monitor this and if it continues to be unreliable we can disable it in the github actions GUI instead of commenting it out like we did before. * ci : add ccache to jobs (non-docker builds) The ccache will only be saved on pushed to master. * ci : bump ccache-action version to v1.2.21 The motivation for this is that the save parameter does not seem to work with the current version. * ci : add ccache to docker jobs in build-linux.yml * ci : add debug statements to linux docker build * ci : set CCACHE_DIR for build-linux.yml * ci : add ccache to the remaining docker jobs * ci : remove build-linux.yml This commit remove build-linux.yml as the same jobs are also run by build-gcc.yml, with the exception that build-gcc.yml also run ctest). So keeping build-gcc.yml and removing the redundant build-linux.yml. * ci : add linux build artifacts to release * ci : revert to hendrikmuhs/ccache-action for win job This is currently causing the following failure: ```console sccache C:\PROGRA~1\NVIDIA~1\CUDA\v\bin\nvcc.exe -forward-unknown-to-host-compiler -DGGML_BACKEND_BUILD -DGGML_BACKEND_SHARED -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_SCHED_MAX_COPIES=4 -DGGML_SHARED -D_CRT_SECURE_NO_WARNINGS -D_XOPEN_SOURCE=600 -Dggml_cuda_EXPORTS -DCMAKE_INTDIR=\"Release\" -ID:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\.. -ID:\a\whisper.cpp\whisper.cpp\ggml\src\..\include -isystem "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v\include" -Xcompiler="-MD -O2 -Ob2" -DNDEBUG -std=c++17 -arch=native -use_fast_math -extended-lambda -Xcompiler /Zc:preprocessor -MD -MT ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -MF ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj.d -x cu -c D:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\allreduce.cu -o ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -Xcompiler=-Fdggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\,-FS sccache: encountered fatal error sccache: error: Could not parse shell line sccache: caused by: Could not parse shell line ``` Refs: https://github.com/danbev/whisper.cpp/actions/runs/26883673904/job/79290017353 * ci : make static linux artifacts * ci : make linux release artifact names consistent This commit removes the tag form the linux release artifacts to be consistent with the existing artifacts. If we want to include the tag then we can do that in a follow-up PR. * ci : fix linux zip files to have a directory * ci : add HF_TOKEN secret for HF download authorization This is to avoid the HR rate limiting when downloading model. --------- Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com> --- .github/actions/ccache-clear/action.yml | 22 + .github/workflows/build-android.yml | 80 ++ .github/workflows/build-clang.yml | 121 ++ .github/workflows/build-coreml.yml | 65 + .github/workflows/build-cpu.yml | 173 +++ .github/workflows/build-freebsd.yml | 47 + .github/workflows/build-gcc.yml | 166 +++ .github/workflows/build-macos.yml | 72 ++ .github/workflows/build-quantize.yml | 41 + .github/workflows/build-sanitize.yml | 82 ++ .github/workflows/build-self-hosted.yml | 116 ++ .github/workflows/build-sycl.yml | 132 ++ .github/workflows/build-vad.yml | 43 + .github/workflows/build-wasm.yml | 51 + .github/workflows/build-windows.yml | 76 ++ .github/workflows/build.yml | 1573 ----------------------- .github/workflows/examples.yml | 2 + .github/workflows/release.yml | 649 ++++++++++ ci/run.sh | 7 + models/download-ggml-model.sh | 8 +- 20 files changed, 1952 insertions(+), 1574 deletions(-) create mode 100644 .github/actions/ccache-clear/action.yml create mode 100644 .github/workflows/build-android.yml create mode 100644 .github/workflows/build-clang.yml create mode 100644 .github/workflows/build-coreml.yml create mode 100644 .github/workflows/build-cpu.yml create mode 100644 .github/workflows/build-freebsd.yml create mode 100644 .github/workflows/build-gcc.yml create mode 100644 .github/workflows/build-macos.yml create mode 100644 .github/workflows/build-quantize.yml create mode 100644 .github/workflows/build-sanitize.yml create mode 100644 .github/workflows/build-self-hosted.yml create mode 100644 .github/workflows/build-sycl.yml create mode 100644 .github/workflows/build-vad.yml create mode 100644 .github/workflows/build-wasm.yml create mode 100644 .github/workflows/build-windows.yml delete mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/release.yml diff --git a/.github/actions/ccache-clear/action.yml b/.github/actions/ccache-clear/action.yml new file mode 100644 index 00000000000..d38587efaf8 --- /dev/null +++ b/.github/actions/ccache-clear/action.yml @@ -0,0 +1,22 @@ +name: "ccache-clear" +description: "Delete all GitHub Actions caches matching a key prefix" +inputs: + key: + description: "Cache key prefix to match and delete" + required: true + +runs: + using: "composite" + steps: + - name: Clear caches + shell: bash + run: | + CACHES=$(gh cache list --key "ccache-${{ inputs.key }}" --json id,key --jq '.[] | "\(.id) \(.key)"' 2>/dev/null) + if [ -z "$CACHES" ]; then + echo "No caches found with key prefix: ${{ inputs.key }}" + exit 0 + fi + while read -r id key; do + echo "Deleting cache: $id ($key)" + gh cache delete "$id" + done <<< "$CACHES" diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml new file mode 100644 index 00000000000..d9af1810131 --- /dev/null +++ b/.github/workflows/build-android.yml @@ -0,0 +1,80 @@ +name: CI (android) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-android.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.java'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + android: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@v6 + with: + path: whisper + + - name: Install Java + uses: actions/setup-java@v5 + with: + distribution: zulu + java-version: 21 + + - name: Setup Android SDK + uses: android-actions/setup-android@v3 + + - name: Build + run: | + cd whisper/examples/whisper.android + ./gradlew assembleRelease --no-daemon + + - name: Build with external ggml + run: | + export PATH_TO_GGML=$PWD/ggml + cd whisper/examples/whisper.android + ./gradlew assembleRelease --no-daemon + + android_java: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: set up JDK 11 + uses: actions/setup-java@v5 + with: + java-version: '11' + distribution: 'temurin' + cache: gradle + + - name: Setup Android SDK + uses: android-actions/setup-android@v3 + with: + cmdline-tools-version: 9.0 + + - name: Build + run: | + cd examples/whisper.android.java + chmod +x ./gradlew + ./gradlew assembleRelease diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml new file mode 100644 index 00000000000..c7a36884f64 --- /dev/null +++ b/.github/workflows/build-clang.yml @@ -0,0 +1,121 @@ +name: CI (clang) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-clang.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + ubuntu_image: "ubuntu:22.04" + +jobs: + ubuntu-22-clang: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + # TODO: arm/v7 disabled due to clang bug + # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 + arch: [linux/amd64, linux/ppc64le] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: clang-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y clang build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' + + ubuntu-22-clang-arm64: + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: clang-arm64-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y clang build-essential cmake libsdl2-dev git + + - name: Build and Test + run: | + cmake . -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + make + ctest -L gh --output-on-failure diff --git a/.github/workflows/build-coreml.yml b/.github/workflows/build-coreml.yml new file mode 100644 index 00000000000..d383d9ae0a7 --- /dev/null +++ b/.github/workflows/build-coreml.yml @@ -0,0 +1,65 @@ +name: CI (coreml) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + tags: + - 'v*' + paths: ['.github/workflows/build-coreml.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal'] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + +jobs: + coreml-base-en: + runs-on: macos-latest + + steps: + - name: Checkout with full history + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Set environment variables + id: set_vars + run: | + BUILD_NUMBER=$(git rev-list --count HEAD) + SHORT_HASH=$(git rev-parse --short=7 HEAD) + if [[ "${{ github.ref_type }}" == "tag" ]]; then + TAG_NAME="${{ github.ref_name }}" + elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + TAG_NAME="b${BUILD_NUMBER}" + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" + fi + echo "MODEL_NAME=base.en" >> $GITHUB_ENV + echo "GEN_MODEL_NAME=whisper-${TAG_NAME}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV + + - name: Download model + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + ./models/download-ggml-model.sh ${{ env.MODEL_NAME }} + + - name: Generate CoreML model + run: | + python3.11 -m venv venv + source venv/bin/activate + pip install ane_transformers openai-whisper coremltools + ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }} diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml new file mode 100644 index 00000000000..9c8e0586fcb --- /dev/null +++ b/.github/workflows/build-cpu.yml @@ -0,0 +1,173 @@ +name: CI (cpu) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-cpu.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +# TODO: simplify the following jobs using a matrix +jobs: + ggml-ci-x64-cpu-low-perf: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-x64-cpu-low-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-low-perf: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-low-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-x64-cpu-high-perf: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-x64-cpu-high-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-high-perf: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-high-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-high-perf-sve: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-high-perf-sve + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt diff --git a/.github/workflows/build-freebsd.yml b/.github/workflows/build-freebsd.yml new file mode 100644 index 00000000000..847ae975e30 --- /dev/null +++ b/.github/workflows/build-freebsd.yml @@ -0,0 +1,47 @@ +name: CI (freebsd) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-freebsd.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + freeBSD-latest: + runs-on: macos-13 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Build + uses: cross-platform-actions/action@v0.27.0 + with: + operating_system: freebsd + version: '14.2' + run: | + sudo pkg update + sudo pkg install -y gmake sdl2 cmake git + cmake -B build + cmake --build build --config Release diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml new file mode 100644 index 00000000000..4528ba3d534 --- /dev/null +++ b/.github/workflows/build-gcc.yml @@ -0,0 +1,166 @@ +name: CI (gcc) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-gcc.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + ubuntu_image: "ubuntu:22.04" + +jobs: + ubuntu-22-gcc: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + arch: [linux/amd64, linux/ppc64le] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' + + ubuntu-22-gcc-arm64: + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-arm64-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake libsdl2-dev git + + - name: Configure CMake + run: | + cmake . \ + -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + + - name: Build and Test + run: | + make + ctest -L gh --output-on-failure + + ubuntu-22-gcc-arm-v7: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + arch: [linux/arm/v7] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv7-a+fp \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml new file mode 100644 index 00000000000..804f8bbb642 --- /dev/null +++ b/.github/workflows/build-macos.yml @@ -0,0 +1,72 @@ +name: CI (macOS) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-macos.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + macOS-latest: + runs-on: macOS-latest + + strategy: + matrix: + destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: macos-${{ matrix.destination }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + run: | + brew update + cmake --version + brew install sdl2 + + - name: Build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DWHISPER_BUILD_EXAMPLES=OFF \ + -DWHISPER_BUILD_TESTS=OFF \ + -DWHISPER_BUILD_SERVER=OFF \ + -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml new file mode 100644 index 00000000000..8036a3a3450 --- /dev/null +++ b/.github/workflows/build-quantize.yml @@ -0,0 +1,41 @@ +name: CI (quantize) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-quantize.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + quantize: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Test quantize + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + ./models/download-ggml-model.sh tiny.en + cmake -B build + cmake --build build --config Release + ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 diff --git a/.github/workflows/build-sanitize.yml b/.github/workflows/build-sanitize.yml new file mode 100644 index 00000000000..9250fe81023 --- /dev/null +++ b/.github/workflows/build-sanitize.yml @@ -0,0 +1,82 @@ +name: CI (sanitize) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-sanitize.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' + - 'bindings/go/**' + - 'examples/addon.node/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-22-gcc-sanitized: + runs-on: ubuntu-22.04 + + continue-on-error: true + + strategy: + fail-fast: false + matrix: + sanitizer: [ADDRESS, THREAD, UNDEFINED] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sanitize-${{ matrix.sanitizer }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake git + + - name: Build (undefined) + if: ${{ matrix.sanitizer == 'UNDEFINED' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=Debug \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + + - name: Build + if: ${{ matrix.sanitizer == 'ADDRESS' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON + make + + - name: Build (no OpenMP) + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + + - name: Test + if: ${{ matrix.sanitizer != 'UNDEFINED' }} + run: | + ctest -L gh --output-on-failure diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml new file mode 100644 index 00000000000..3fe131b9ba5 --- /dev/null +++ b/.github/workflows/build-self-hosted.yml @@ -0,0 +1,116 @@ +name: CI (self-hosted) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal', + '**/*.comp' + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-self-hosted.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal', + '**/*.comp' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + gpu-cuda: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + nvidia-smi + GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan-nvidia-cm: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan-nvidia-cm2: + runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-metal: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml new file mode 100644 index 00000000000..57aa7cc4d95 --- /dev/null +++ b/.github/workflows/build-sycl.yml @@ -0,0 +1,132 @@ +name: CI (sycl) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-sycl.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-22-cmake-sycl: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp git + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel git + + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + mkdir build + cd build + cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake --build . --config Release -j $(nproc) + + ubuntu-22-cmake-sycl-fp16: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp git + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + mkdir build + cd build + cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake --build . --config Release -j $(nproc) diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml new file mode 100644 index 00000000000..71e910a3fcb --- /dev/null +++ b/.github/workflows/build-vad.yml @@ -0,0 +1,43 @@ +name: CI (vad) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-vad.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + vad: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Build + shell: bash + run: | + cmake -B build + cmake --build build --config Release + + - name: Test + shell: bash + run: | + ctest -R ^test-vad$ --test-dir build --output-on-failure -VV diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml new file mode 100644 index 00000000000..42a9401af3c --- /dev/null +++ b/.github/workflows/build-wasm.yml @@ -0,0 +1,51 @@ +name: CI (wasm) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-wasm.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + emscripten: + runs-on: ubuntu-22.04 + + strategy: + matrix: + build: [Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Setup emsdk + uses: mymindstorm/setup-emsdk@v14 + + - name: Verify + run: emcc -v + + - name: Build + run: | + emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} + make diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml new file mode 100644 index 00000000000..cd1591f0132 --- /dev/null +++ b/.github/workflows/build-windows.yml @@ -0,0 +1,76 @@ +name: CI (windows) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-windows.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + windows-msys2: + runs-on: windows-latest + + strategy: + fail-fast: false + matrix: + include: + - { sys: UCRT64, env: ucrt-x86_64, build: Release } + - { sys: CLANG64, env: clang-x86_64, build: Release } + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Setup ${{ matrix.sys }} + uses: msys2/setup-msys2@v2 + with: + update: true + msystem: ${{matrix.sys}} + install: >- + base-devel + git + mingw-w64-${{matrix.env}}-toolchain + mingw-w64-${{matrix.env}}-cmake + mingw-w64-${{matrix.env}}-SDL2 + mingw-w64-${{matrix.env}}-openblas + + - name: Build using CMake + shell: msys2 {0} + run: | + cmake -B build -DWHISPER_SDL2=ON + cmake --build build --config ${{ matrix.build }} -j $(nproc) + + - name: Clean after building using CMake + shell: msys2 {0} + run: | + rm -rf build + + - name: Build using CMake w/ OpenBLAS + shell: msys2 {0} + run: | + cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS + cmake --build build --config ${{ matrix.build }} -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index b7badd51041..00000000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,1573 +0,0 @@ -name: CI - -on: - push: - branches: - - master - tags: - - 'v*' - paths: ['.github/workflows/build.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl', - '**/*.swift', - '**/*.m', - '**/*.mm', - '**/*.metal', - '**/*.comp', - '**/*.java'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - workflow_dispatch: - inputs: - create_release: - description: 'Create new release' - required: true - type: boolean - pre_release_tag: - description: 'Pre-release tag name' - required: false - type: string - run_type: - description: 'Workflow type to run' - required: true - type: choice - options: - - full-ci - - release-only - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -permissions: - contents: write # for creating release - -env: - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - ubuntu_image: "ubuntu:22.04" - VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" - -jobs: - determine-tag: - runs-on: ubuntu-latest - outputs: - tag_name: ${{ steps.tag.outputs.name }} - should_release: ${{ steps.tag.outputs.should_release }} - - steps: - - name: Checkout with full history - uses: actions/checkout@v6 - with: - fetch-depth: 0 - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER=$(git rev-list --count HEAD) - SHORT_HASH=$(git rev-parse --short=7 HEAD) - CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}" - SHOULD_RELEASE="false" - - echo "Raw values:" - echo "BUILD_NUMBER: $BUILD_NUMBER" - echo "SHORT_HASH: $SHORT_HASH" - echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}" - echo "CUSTOM_TAG: $CUSTOM_TAG" - - if [[ "${{ github.ref_type }}" == "tag" ]]; then - echo "Using pushed tag name" - TAG_NAME="${{ github.ref_name }}" - SHOULD_RELEASE="true" - elif [[ -n "$CUSTOM_TAG" ]]; then - echo "Using custom tag" - TAG_NAME="${CUSTOM_TAG}" - SHOULD_RELEASE="true" - elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then - echo "Manual release requested" - SHOULD_RELEASE="true" - TAG_NAME="b${BUILD_NUMBER}" - elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "Using master branch format" - TAG_NAME="b${BUILD_NUMBER}" - SHOULD_RELEASE="false" - else - echo "Using non-master branch format" - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" - SHOULD_RELEASE="false" - fi - - echo "Final tag name: $TAG_NAME" - echo "Should release: $SHOULD_RELEASE" - echo "name=$TAG_NAME" >> $GITHUB_OUTPUT - echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT - - - ubuntu-22: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build - cmake --build build --config Release -j $(nproc)' - - ubuntu-22-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential libsdl2-dev cmake git - - - name: Build - run: | - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - cmake --build build --config Release -j $(nproc) - - ubuntu-22-arm-v7: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/arm/v7] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp - cmake --build build --config Release -j $(nproc)' - - macOS-latest: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: macOS-latest - - strategy: - matrix: - destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: macOS-latest-swift - evict-old-files: 1d - - - name: Dependencies - run: | - brew update - cmake --version - brew install sdl2 - - - name: Build - run: | - sysctl -a - cmake -B build -G Xcode \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" - cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) - - -# freeBSD-latest: -# runs-on: macos-13 -# -# steps: -# - name: Clone -# uses: actions/checkout@v6 -# -# - name: Build -# uses: cross-platform-actions/action@v0.27.0 -# with: -# operating_system: freebsd -# version: '14.2' -# run: | -# sudo pkg update -# sudo pkg install -y gmake sdl2 cmake git -# cmake -B build -# cmake --build build --config Release - - ubuntu-22-gcc: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make - ctest -L gh --output-on-failure' - - ubuntu-22-gcc-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04-arm - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential cmake libsdl2-dev git - - - name: Configure CMake - run: | - cmake . \ - -DWHISPER_SDL2=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_NATIVE=OFF \ - -DGGML_CPU_ARM_ARCH=armv8-a - - - name: Build and Test - run: | - make - ctest -L gh --output-on-failure - - ubuntu-22-gcc-arm-v7: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/arm/v7] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp - make - ctest -L gh --output-on-failure' - - ubuntu-22-clang: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - # TODO: arm/v7 disabled due to clang bug - # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y clang build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang - make - ctest -L gh --output-on-failure' - - ubuntu-22-clang-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04-arm - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y clang build-essential cmake libsdl2-dev git - - - name: Build and Test - run: | - cmake . -DWHISPER_SDL2=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_C_COMPILER=clang \ - -DGGML_NATIVE=OFF \ - -DGGML_CPU_ARM_ARCH=armv8-a - make - ctest -L gh --output-on-failure - - ubuntu-22-gcc-sanitized: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - sanitizer: [ADDRESS, THREAD, UNDEFINED] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential cmake git - - - name: Build and Test - run: | - cmake . -DCMAKE_BUILD_TYPE=Debug \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF - make - ctest -L gh --output-on-failure - - ubuntu-22-cmake-sycl: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel git - - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. - cmake --build . --config Release -j $(nproc) - - ubuntu-22-cmake-sycl-fp16: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel - - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. - cmake --build . --config Release -j $(nproc) - - windows-msys2: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - - strategy: - fail-fast: false - matrix: - include: - - { sys: UCRT64, env: ucrt-x86_64, build: Release } - - { sys: CLANG64, env: clang-x86_64, build: Release } - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Setup ${{ matrix.sys }} - uses: msys2/setup-msys2@v2 - with: - update: true - msystem: ${{matrix.sys}} - install: >- - base-devel - git - mingw-w64-${{matrix.env}}-toolchain - mingw-w64-${{matrix.env}}-cmake - mingw-w64-${{matrix.env}}-SDL2 - mingw-w64-${{matrix.env}}-openblas - - - name: Build using CMake - shell: msys2 {0} - run: | - cmake -B build -DWHISPER_SDL2=ON - cmake --build build --config ${{ matrix.build }} -j $(nproc) - - - name: Clean after building using CMake - shell: msys2 {0} - run: | - rm -rf build - - - name: Build using CMake w/ OpenBLAS - shell: msys2 {0} - run: | - cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS - cmake --build build --config ${{ matrix.build }} -j $(nproc) - - windows: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - sdl2: [ON] - include: - - arch: Win32 - s2arc: x86 - jnaPath: win32-x86 - - arch: x64 - s2arc: x64 - jnaPath: win32-x86-64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DBUILD_SHARED_LIBS=ON - -DWHISPER_SDL2=${{ matrix.sdl2 }} - -DGGML_NATIVE=OFF - -DGGML_BMI2=OFF - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Upload SDL2.dll - if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v6 - with: - name: ${{ matrix.s2arc }}_SDL2.dll - path: build/bin/${{ matrix.build }}/SDL2.dll - - - name: Upload whisper dll - uses: actions/upload-artifact@v6 - with: - name: whisper_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/whisper.dll - - - name: Upload ggml dll - uses: actions/upload-artifact@v6 - with: - name: ggml_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml.dll - overwrite: true - - - name: Upload ggml base dll - uses: actions/upload-artifact@v6 - with: - name: ggml_base_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-base.dll - - - name: Upload ggml cpu dll - uses: actions/upload-artifact@v6 - with: - name: ggml_cpu_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-cpu.dll - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - name: whisper-bin-${{ matrix.arch }}.zip - path: whisper-bin-${{ matrix.arch }}.zip - - windows-blas: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - blas: [ON] - sdl2: [ON] - blasver: [0.3.29] - include: - - arch: Win32 - s2arc: x86 - blasfile: x86 - - arch: x64 - s2arc: x64 - blasfile: x64_64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v8 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Install OpenBLAS and pkgconfiglite - if: matrix.blas == 'ON' - run: | - Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip" - Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}" - choco install pkgconfiglite - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DGGML_BLAS=${{ matrix.blas }} - -DGGML_BLAS_VENDOR=OpenBLAS - -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" - -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" - -DWHISPER_SDL2=${{ matrix.sdl2 }} - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy openblas.dll - if: matrix.blas == 'ON' - run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - name: whisper-blas-bin-${{ matrix.arch }}.zip - path: whisper-blas-bin-${{ matrix.arch }}.zip - - windows-cublas: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-2022 - needs: determine-tag - strategy: - fail-fast: false - matrix: - build: [Release] - arch: [x64] - cublas: [ON] - sdl2: [ON] - cuda-toolkit: [12.4.0, 11.8.0] - include: - - arch: x64 - sdl2: ON - sdl2_ver: 2.28.5 - steps: - - name: Clone repository - uses: actions/checkout@v6 - - - name: Install Ninja - id: install_ninja - run: | - choco install ninja - - - name: Install ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} - variant: sccache - evict-old-files: 5d - - - name: Install Cuda Toolkit 11.8.0 - if: ${{ matrix.cuda-toolkit == '11.8.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "11.8.89" - $NVCC_VER = "11.8.89" - $NVRTC_VER = "11.8.89" - $CUBLAS_VER = "11.8.1.74" - $NVTX_VER = "11.8.86" - $VS_VER = "11.8.86" - $NVPROF_VER = "11.8.87" - $CCCL_VER = "11.8.89" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Install Cuda Toolkit 12.4.0 - if: ${{ matrix.cuda-toolkit == '12.4.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "12.4.127" - $NVCC_VER = "12.4.131" - $NVRTC_VER = "12.4.127" - $CUBLAS_VER = "12.4.5.8" - $NVTX_VER = "12.4.127" - $PROFILER_VER = "12.4.127" - $VS_VER = "12.4.127" - $NVPROF_VER = "12.4.128" - $CCCL_VER = "12.4.127" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Install 7-Zip - run: choco install 7zip -y - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip - 7z x sdl2.zip - echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append - echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt - - - name: Install cmake - run: choco install cmake - - - name: Build Project - shell: cmd - run: | - call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake --version - where cmake - if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR - ) else ( - set CUDA_FLAGS= - ) - cmake -S . -B build -G "Ninja Multi-Config" ^ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^ - -DGGML_CUDA=${{ matrix.cublas }} ^ - -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ - -DSDL2_DIR="%SDL2_DIR%" ^ - -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ - -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" - set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 - cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% - - - name: Check sccache status after build - run: | - sccache --show-stats - - - name: Copy CUDA DLLs - run: | - Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | - Copy-Item -Destination "build/bin/${{ matrix.build }}" - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - - emscripten: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - matrix: - build: [Release] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 - - - name: Verify - run: emcc -v - - - name: Build - run: | - emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make - - ios-xcode-build: - runs-on: macos-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Configure - run: | - cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin - mkdir models/ggml-base.en-encoder.mlmodelc - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_SYSTEM_NAME=iOS \ - -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ - -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO - - - name: xcodebuild for swift package - id: xcodebuild - run: | - ./build-xcframework.sh - - - name: Build objc example - run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Build swiftui example - run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Pack artifacts - id: pack_artifacts - run: | - zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework - - - name: Upload artifacts - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - - android: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v6 - with: - path: whisper - - - name: Install Java - uses: actions/setup-java@v5 - with: - distribution: zulu - java-version: 21 - - - name: Setup Android SDK - uses: android-actions/setup-android@v3 - - - name: Build - run: | - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - - name: Build with external ggml - run: | - export PATH_TO_GGML=$PWD/ggml - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - android_java: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: set up JDK 11 - uses: actions/setup-java@v5 - with: - java-version: '11' - distribution: 'temurin' - cache: gradle - - - name: Setup Android SDK - uses: android-actions/setup-android@v3 - with: - cmdline-tools-version: 9.0 - - - name: Build - run: | - cd examples/whisper.android.java - chmod +x ./gradlew - ./gradlew assembleRelease - - bindings-java: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - needs: ['windows'] - runs-on: windows-latest - steps: - - uses: actions/checkout@v6 - - - name: Install Java - uses: actions/setup-java@v5 - with: - distribution: zulu - java-version: 20 - - - name: Download Whisper Windows lib - uses: actions/download-artifact@v7 - with: - name: whisper_x64.dll - - - name: Download GGML Windows lib - uses: actions/download-artifact@v7 - with: - name: ggml_x64.dll - - - name: Download GGML Base Windows lib - uses: actions/download-artifact@v7 - with: - name: ggml_base_x64.dll - - - name: Download GGML CPU Windows lib - uses: actions/download-artifact@v7 - with: - name: ggml_cpu_x64.dll - - - name: Download SDL2.dll - uses: actions/download-artifact@v7 - with: - name: x64_SDL2.dll - - - name: List downloaded files - shell: pwsh - run: | - Get-ChildItem -Path "." -Recurse -Filter "*.dll" - - - name: Move DLL to correct location - shell: pwsh - run: | - New-Item -Path "build\bin\Release" -ItemType Directory -Force - - Copy-Item -Path "whisper.dll" -Destination "build\bin\Release\whisper.dll" -Force - Write-Host "Copied whisper.dll to build\bin\Release\whisper.dll directory" - - Copy-Item -Path "ggml.dll" -Destination "build\bin\Release\ggml.dll" -Force - Write-Host "Copied ggml.dll to build\bin\Release\ggml.dll directory" - - Copy-Item -Path "ggml-base.dll" -Destination "build\bin\Release\ggml-base.dll" -Force - Write-Host "Copied ggml-base.dll to build\bin\Release\ggml-base.dll directory" - - Copy-Item -Path "ggml-cpu.dll" -Destination "build\bin\Release\ggml-cpu.dll" -Force - Write-Host "Copied ggml-cpu.dll to build\bin\Release\ggml-cpu.dll directory" - - Copy-Item -Path "SDL2.dll" -Destination "build\bin\Release\SDL2.dll" -Force - Write-Host "Copied SDL2.dll to build\bin\Release\SDL2.dll directory" - - - name: List build release files - shell: pwsh - run: | - Get-ChildItem -Path "build\Release" -Recurse -Filter "*.dll" - - - name: Build - run: | - models\download-ggml-model.cmd tiny.en models/ - cd bindings/java - chmod +x ./gradlew - ./gradlew build --info - - - name: Pack jar artifacts - shell: pwsh - run: | - Compress-Archive -Path "bindings/java/build/libs/whispercpp-*.jar" -DestinationPath "whispercpp.jar.zip" - - - name: Upload jar - uses: actions/upload-artifact@v6 - with: - name: whispercpp.jar.zip - path: whispercpp.jar.zip - -# - name: Publish package -# if: ${{ github.ref == 'refs/heads/master' }} -# uses: gradle/gradle-build-action@v2.4.2 -# with: -# arguments: publish -# build-root-directory: bindings/java -# env: -# MAVEN_USERNAME: ${{ secrets.JIRA_USER }} -# MAVEN_PASSWORD: ${{ secrets.JIRA_PASS }} -# PGP_SECRET: ${{ secrets.GPG_PRIVATE_KEY }} -# PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} - - quantize: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Test quantize - run: | - ./models/download-ggml-model.sh tiny.en - cmake -B build - cmake --build build --config Release - ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 - - release: - if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }} - - runs-on: ubuntu-latest - - needs: - - determine-tag - - ios-xcode-build - - windows - - windows-blas - - windows-cublas - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - with: - fetch-depth: 0 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: release - evict-old-files: 1d - - # Downloads all the artifacts from the previous jobs - - name: Download artifacts - id: download-artifact - uses: actions/download-artifact@v7 - with: - path: ./artifact - - - name: Move artifacts - id: move_artifacts - run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release - - - name: Create release - id: create_release - uses: ggml-org/action-create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ needs.determine-tag.outputs.tag_name }} - prerelease: ${{ github.event.inputs.pre_release_tag != '' }} - draft: true - - - name: Upload release - id: upload_release - uses: actions/github-script@v3 - with: - github-token: ${{secrets.GITHUB_TOKEN}} - script: | - const path = require('path'); - const fs = require('fs'); - const release_id = '${{ steps.create_release.outputs.id }}'; - for (let file of await fs.readdirSync('./artifact/release')) { - if (path.extname(file) === '.zip') { - console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ - owner: context.repo.owner, - repo: context.repo.repo, - release_id: release_id, - name: file, - data: await fs.readFileSync(`./artifact/release/${file}`) - }); - } - } - - coreml-base-en: - if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') || - github.event.inputs.create_release == 'true' || - github.event.inputs.pre_release_tag != '' || - startsWith(github.ref, 'refs/tags/v') }} - runs-on: macos-latest - needs: determine-tag - - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Set environment variables - id: set_vars - run: | - echo "MODEL_NAME=base.en" >> $GITHUB_ENV - echo "GEN_MODEL_NAME=whisper-${{ needs.determine-tag.outputs.tag_name }}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV - - - name: Download model - run: | - ./models/download-ggml-model.sh ${{ env.MODEL_NAME }} - - - name: Generate CoreML model - run: | - python3.11 -m venv venv - source venv/bin/activate - pip install ane_transformers openai-whisper coremltools - ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }} - - vad: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v6 - - - name: Build - shell: bash - run: | - cmake -B build - cmake --build build --config Release - - - name: Test - shell: bash - run: | - ctest -R ^test-vad$ --test-dir build --output-on-failure -VV - -# TODO: simplify the following workflows using a matrix - ggml-ci-x64-cpu-low-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-x64-cpu-low-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-low-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-low-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-x64-cpu-high-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-x64-cpu-high-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-high-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf-sve: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-high-perf-sve - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-x64-nvidia-cuda: - runs-on: [self-hosted, Linux, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - nvidia-smi - GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-x64-nvidia-vulkan-cm: - runs-on: [self-hosted, Linux, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-x64-nvidia-vulkan-cm2: - runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - #ggml-ci-x64-cpu-amx: - # runs-on: [self-hosted, Linux, X64, CPU, AMX] - - # steps: - # - name: Clone - # id: checkout - # uses: actions/checkout@v6 - - # - name: Test - # id: ggml-ci - # run: | - # bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-mac-metal: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-mac-vulkan: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index df3aa832c2e..eaa4fe4df61 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -42,6 +42,8 @@ jobs: run: npx cmake-js compile -T addon.node -B Release - name: Download test model + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | bash ./models/download-ggml-model.sh base.en - name: Test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000000..2ba8b45093b --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,649 @@ +name: Release + +on: + workflow_dispatch: + inputs: + create_release: + description: 'Create new release' + required: true + type: boolean + pre_release_tag: + description: 'Pre-release tag name' + required: false + type: string + + push: + branches: + - master + tags: + - 'v*' + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +permissions: + contents: write # for creating release + +jobs: + determine-tag: + runs-on: ubuntu-latest + outputs: + tag_name: ${{ steps.tag.outputs.name }} + should_release: ${{ steps.tag.outputs.should_release }} + + steps: + - name: Checkout with full history + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER=$(git rev-list --count HEAD) + SHORT_HASH=$(git rev-parse --short=7 HEAD) + CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}" + SHOULD_RELEASE="false" + + echo "Raw values:" + echo "BUILD_NUMBER: $BUILD_NUMBER" + echo "SHORT_HASH: $SHORT_HASH" + echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}" + echo "CUSTOM_TAG: $CUSTOM_TAG" + + if [[ "${{ github.ref_type }}" == "tag" ]]; then + echo "Using pushed tag name" + TAG_NAME="${{ github.ref_name }}" + SHOULD_RELEASE="true" + elif [[ -n "$CUSTOM_TAG" ]]; then + echo "Using custom tag" + TAG_NAME="${CUSTOM_TAG}" + SHOULD_RELEASE="true" + elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then + echo "Manual release requested" + SHOULD_RELEASE="true" + TAG_NAME="b${BUILD_NUMBER}" + elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "Using master branch format" + TAG_NAME="b${BUILD_NUMBER}" + SHOULD_RELEASE="false" + else + echo "Using non-master branch format" + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" + SHOULD_RELEASE="false" + fi + + echo "Final tag name: $TAG_NAME" + echo "Should release: $SHOULD_RELEASE" + echo "name=$TAG_NAME" >> $GITHUB_OUTPUT + echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT + + ubuntu-cpu: + runs-on: ${{ matrix.os }} + needs: determine-tag + if: ${{ needs.determine-tag.outputs.should_release == 'true' }} + + strategy: + matrix: + include: + - build: x64 + os: ubuntu-22.04 + - build: arm64 + os: ubuntu-22.04-arm + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: release-${{ matrix.os }}-cpu + evict-old-files: 1d + + - name: Dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake + + - name: Build + run: | + cmake -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=OFF \ + -DGGML_NATIVE=OFF \ + ${{ matrix.build == 'arm64' && '-DGGML_CPU_ARM_ARCH=armv8-a' || '' }} + cmake --build build --config Release -j $(nproc) + + - name: Pack artifacts + run: | + cp LICENSE ./build/bin/ + tar -czvf whisper-bin-ubuntu-${{ matrix.build }}.tar.gz \ + --transform "s,^\.,whisper-bin-ubuntu-${{ matrix.build }}," \ + -C ./build/bin . + + - name: Upload artifacts + uses: actions/upload-artifact@v6 + with: + path: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz + name: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz + + windows: + runs-on: windows-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + arch: [Win32, x64] + sdl2: [ON] + include: + - arch: Win32 + s2arc: x86 + jnaPath: win32-x86 + - arch: x64 + s2arc: x64 + jnaPath: win32-x86-64 + - sdl2: ON + s2ver: 2.28.5 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v2 + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV + + - name: Configure + run: > + cmake -S . -B ./build -A ${{ matrix.arch }} + -DCMAKE_BUILD_TYPE=${{ matrix.build }} + -DBUILD_SHARED_LIBS=ON + -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + -DGGML_BMI2=OFF + + - name: Build + run: | + cd ./build + msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Upload SDL2.dll + if: matrix.sdl2 == 'ON' + uses: actions/upload-artifact@v6 + with: + name: ${{ matrix.s2arc }}_SDL2.dll + path: build/bin/${{ matrix.build }}/SDL2.dll + + - name: Upload whisper dll + uses: actions/upload-artifact@v6 + with: + name: whisper_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/whisper.dll + + - name: Upload ggml dll + uses: actions/upload-artifact@v6 + with: + name: ggml_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml.dll + overwrite: true + + - name: Upload ggml base dll + uses: actions/upload-artifact@v6 + with: + name: ggml_base_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml-base.dll + + - name: Upload ggml cpu dll + uses: actions/upload-artifact@v6 + with: + name: ggml_cpu_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml-cpu.dll + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + name: whisper-bin-${{ matrix.arch }}.zip + path: whisper-bin-${{ matrix.arch }}.zip + + windows-blas: + runs-on: windows-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + arch: [Win32, x64] + blas: [ON] + sdl2: [ON] + blasver: [0.3.29] + include: + - arch: Win32 + s2arc: x86 + blasfile: x86 + - arch: x64 + s2arc: x64 + blasfile: x64_64 + - sdl2: ON + s2ver: 2.28.5 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v8 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v2 + + - name: Install OpenBLAS and pkgconfiglite + if: matrix.blas == 'ON' + run: | + Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip" + Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}" + choco install pkgconfiglite + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV + + - name: Configure + run: > + cmake -S . -B ./build -A ${{ matrix.arch }} + -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" + -DCMAKE_BUILD_TYPE=${{ matrix.build }} + -DGGML_BLAS=${{ matrix.blas }} + -DGGML_BLAS_VENDOR=OpenBLAS + -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" + -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" + -DWHISPER_SDL2=${{ matrix.sdl2 }} + + - name: Build + run: | + cd ./build + msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} + + - name: Copy openblas.dll + if: matrix.blas == 'ON' + run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }} + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + name: whisper-blas-bin-${{ matrix.arch }}.zip + path: whisper-blas-bin-${{ matrix.arch }}.zip + + windows-cublas: + runs-on: windows-2022 + needs: determine-tag + strategy: + fail-fast: false + matrix: + build: [Release] + arch: [x64] + cublas: [ON] + sdl2: [ON] + cuda-toolkit: [12.4.0, 11.8.0] + include: + - arch: x64 + sdl2: ON + sdl2_ver: 2.28.5 + steps: + - name: Clone repository + uses: actions/checkout@v6 + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} + variant: sccache + evict-old-files: 5d + + - name: Install Cuda Toolkit 11.8.0 + if: ${{ matrix.cuda-toolkit == '11.8.0' }} + run: | + $CUDA_VERSION = ${{ matrix.cuda-toolkit }} + $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" + $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" + + # Components versions + $CUDART_VER = "11.8.89" + $NVCC_VER = "11.8.89" + $NVRTC_VER = "11.8.89" + $CUBLAS_VER = "11.8.1.74" + $NVTX_VER = "11.8.86" + $VS_VER = "11.8.86" + $NVPROF_VER = "11.8.87" + $CCCL_VER = "11.8.89" + + # Create the directory where the CUDA Toolkit will be installed + mkdir -p $CUDA_TOOLKIT_DIR + + # Install unzip to extract the downloaded files + choco install unzip -y + + # Download all the required components + curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" + + # Extract all the downloaded files to the CUDA Toolkit directory + unzip '*.zip' -d $CUDA_TOOLKIT_DIR + + # Copy all the extracted files to the main CUDA Toolkit directory + xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + + # Visual Studio integration + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y + + # Set environment variables + echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Install Cuda Toolkit 12.4.0 + if: ${{ matrix.cuda-toolkit == '12.4.0' }} + run: | + $CUDA_VERSION = ${{ matrix.cuda-toolkit }} + $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" + $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" + + # Components versions + $CUDART_VER = "12.4.127" + $NVCC_VER = "12.4.131" + $NVRTC_VER = "12.4.127" + $CUBLAS_VER = "12.4.5.8" + $NVTX_VER = "12.4.127" + $PROFILER_VER = "12.4.127" + $VS_VER = "12.4.127" + $NVPROF_VER = "12.4.128" + $CCCL_VER = "12.4.127" + + # Create the directory where the CUDA Toolkit will be installed + mkdir -p $CUDA_TOOLKIT_DIR + + # Install unzip to extract the downloaded files + choco install unzip -y + + # Download all the required components + curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" + + # Extract all the downloaded files to the CUDA Toolkit directory + unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR + + # Copy all the extracted files to the main CUDA Toolkit directory + xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + + # Visual Studio integration + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y + + # Set environment variables + echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v2 + + - name: Install 7-Zip + run: choco install 7zip -y + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip + 7z x sdl2.zip + echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt + + - name: Install cmake + run: choco install cmake + + - name: Build Project + shell: cmd + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + cmake --version + where cmake + if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( + set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR + ) else ( + set CUDA_FLAGS= + ) + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^ + -DGGML_CUDA=${{ matrix.cublas }} ^ + -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ + -DSDL2_DIR="%SDL2_DIR%" ^ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ + -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" + set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 + cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% + + - name: Check sccache status after build + run: | + sccache --show-stats + + - name: Copy CUDA DLLs + run: | + Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | + Copy-Item -Destination "build/bin/${{ matrix.build }}" + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip + path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip + + ios-xcode-build: + runs-on: macos-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Configure + run: | + cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin + mkdir models/ggml-base.en-encoder.mlmodelc + + - name: Build + id: cmake_build + run: | + sysctl -a + mkdir build + cd build + cmake -G Xcode .. \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DWHISPER_BUILD_EXAMPLES=OFF \ + -DWHISPER_BUILD_TESTS=OFF \ + -DWHISPER_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + - name: xcodebuild for swift package + id: xcodebuild + run: | + ./build-xcframework.sh + + - name: Build objc example + run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Build swiftui example + run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Pack artifacts + id: pack_artifacts + run: | + zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework + + - name: Upload artifacts + if: ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip + name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip + + release: + if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }} + + runs-on: ubuntu-latest + + needs: + - determine-tag + - ubuntu-cpu + - ios-xcode-build + - windows + - windows-blas + - windows-cublas + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: release + evict-old-files: 1d + + # Downloads all the artifacts from the previous jobs + - name: Download artifacts + id: download-artifact + uses: actions/download-artifact@v7 + with: + path: ./artifact + + - name: Move artifacts + id: move_artifacts + run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release && mv ./artifact/*/*.tar.gz ./artifact/release 2>/dev/null || true + + - name: Create release + id: create_release + uses: ggml-org/action-create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ needs.determine-tag.outputs.tag_name }} + prerelease: ${{ github.event.inputs.pre_release_tag != '' }} + draft: true + + - name: Upload release + id: upload_release + uses: actions/github-script@v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + const path = require('path'); + const fs = require('fs'); + const release_id = '${{ steps.create_release.outputs.id }}'; + for (let file of await fs.readdirSync('./artifact/release')) { + if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { + console.log('uploadReleaseAsset', file); + await github.repos.uploadReleaseAsset({ + owner: context.repo.owner, + repo: context.repo.repo, + release_id: release_id, + name: file, + data: await fs.readFileSync(`./artifact/release/${file}`) + }); + } + } diff --git a/ci/run.sh b/ci/run.sh index b03fdf1c6b1..dca4476a0fa 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -151,8 +151,15 @@ function gg_download_model { local cwd=`pwd` mkdir -p "$MNT/models" cd "$MNT/models" + set -x bash "$cwd/models/download-ggml-model.sh" ${model_name} . + local download_status=$? + set +x cd "$cwd" + if [ $download_status -ne 0 ]; then + echo "Error: failed to download model ${model_name}" + ret=1 + fi fi } diff --git a/models/download-ggml-model.sh b/models/download-ggml-model.sh index f1394e98484..0539c8afb3d 100755 --- a/models/download-ggml-model.sh +++ b/models/download-ggml-model.sh @@ -120,7 +120,13 @@ fi if [ -x "$(command -v wget2)" ]; then wget2 --no-config --progress bar -O ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v curl)" ]; then - curl -L --output ggml-"$model".bin $src/$pfx-"$model".bin + curl -L --fail \ + --retry 5 \ + --retry-delay 5 \ + --retry-all-errors \ + --retry-connrefused \ + ${HF_TOKEN:+--header "Authorization: Bearer $HF_TOKEN"} \ + --output ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v wget)" ]; then wget --no-config --quiet --show-progress -O ggml-"$model".bin $src/$pfx-"$model".bin else From 12d1828837f8ca3ea2b3c94180bcab733fb76092 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 10:30:48 +0200 Subject: [PATCH 205/289] ci : only publish/push docker images daily (#3854) This commit updates the docker workflow to be triggered on a schedule or manually. --- .github/workflows/docker.yml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 9e07f7b2292..51724976e0a 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,9 +1,10 @@ name: Publish Docker image on: - push: - branches: - - master + workflow_dispatch: # allows manual triggering + schedule: + # Rebuild daily rather than on every push because it is expensive + - cron: '12 4 * * *' jobs: push_to_registry: @@ -57,16 +58,14 @@ jobs: id: tags run: | TAGS="ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}" - if [ "${{ github.event_name }}" == "push" ]; then - TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" - fi + TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) uses: docker/build-push-action@v6 with: context: . - push: ${{ github.event_name == 'push' }} + push: true platforms: ${{ matrix.config.platform }} tags: ${{ steps.tags.outputs.tags }} file: ${{ matrix.config.dockerfile }} From 9302c060f0d8178a01aa6b36e9673032fbc9aff8 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 11:37:22 +0200 Subject: [PATCH 206/289] ci : use ccache instead of sccache for windows-cublas [no ci] (#3855) This commit updates the Install cache step to use ggml-org/ccache-action and switched to use ccache instead of sccache. The motivation for switching to ccache is that this is what llama.cpp does and also there is an issue with later version of sscache: ```console sccache C:\PROGRA~1\NVIDIA~1\CUDA\v\bin\nvcc.exe -forward-unknown-to-host-compiler -DGGML_BACKEND_BUILD -DGGML_BACKEND_SHARED -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_SCHED_MAX_COPIES=4 -DGGML_SHARED -D_CRT_SECURE_NO_WARNINGS -D_XOPEN_SOURCE=600 -Dggml_cuda_EXPORTS -DCMAKE_INTDIR=\"Release\" -ID:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\.. -ID:\a\whisper.cpp\whisper.cpp\ggml\src\..\include -isystem "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v\include" -Xcompiler="-MD -O2 -Ob2" -DNDEBUG -std=c++17 -arch=native -use_fast_math -extended-lambda -Xcompiler /Zc:preprocessor -MD -MT ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -MF ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj.d -x cu -c D:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\allreduce.cu -o ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -Xcompiler=-Fdggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\,-FS sccache: encountered fatal error sccache: error: Could not parse shell line sccache: caused by: Could not parse shell line ``` ``` --- .github/workflows/release.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2ba8b45093b..c3ae9de4deb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -340,10 +340,9 @@ jobs: choco install ninja - name: Install ccache - uses: hendrikmuhs/ccache-action@v1.2.16 + uses: ggml-org/ccache-action@v1.2.21 with: key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} - variant: sccache evict-old-files: 5d - name: Install Cuda Toolkit 11.8.0 @@ -497,9 +496,9 @@ jobs: set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% - - name: Check sccache status after build + - name: Check ccache status after build run: | - sccache --show-stats + ccache --show-stats - name: Copy CUDA DLLs run: | From 7ecb08f26359708dbc7fbeea428916684c64a76e Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 11:38:46 +0200 Subject: [PATCH 207/289] ci : pin github actions to commit SHAs (#3856) This commit pins github actions used to the same commi SHAs that llama.cpp uses. --- .github/workflows/build-android.yml | 4 ++-- .github/workflows/build-clang.yml | 2 +- .github/workflows/build-gcc.yml | 4 ++-- .github/workflows/build-windows.yml | 2 +- .github/workflows/docker.yml | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml index d9af1810131..42673166cf3 100644 --- a/.github/workflows/build-android.yml +++ b/.github/workflows/build-android.yml @@ -41,7 +41,7 @@ jobs: java-version: 21 - name: Setup Android SDK - uses: android-actions/setup-android@v3 + uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 - name: Build run: | @@ -69,7 +69,7 @@ jobs: cache: gradle - name: Setup Android SDK - uses: android-actions/setup-android@v3 + uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 with: cmdline-tools-version: 9.0 diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml index c7a36884f64..5308164cc68 100644 --- a/.github/workflows/build-clang.yml +++ b/.github/workflows/build-clang.yml @@ -61,7 +61,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - name: Build ${{ matrix.arch }} run: | diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml index 4528ba3d534..b1b04c24034 100644 --- a/.github/workflows/build-gcc.yml +++ b/.github/workflows/build-gcc.yml @@ -58,7 +58,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - name: Build ${{ matrix.arch }} run: | @@ -141,7 +141,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - name: Build ${{ matrix.arch }} run: | diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index cd1591f0132..9fd910ac0ec 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -46,7 +46,7 @@ jobs: uses: actions/checkout@v6 - name: Setup ${{ matrix.sys }} - uses: msys2/setup-msys2@v2 + uses: msys2/setup-msys2@cafece8e6baf9247cf9b1bf95097b0b983cc558d # v2 with: update: true msystem: ${{matrix.sys}} diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 51724976e0a..e7ca8595ddd 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -30,10 +30,10 @@ jobs: uses: actions/checkout@v6 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Log in to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -62,7 +62,7 @@ jobs: echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) - uses: docker/build-push-action@v6 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7 with: context: . push: true From ad17783d3499d54bd64f8afd19932ea7b0d5d175 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 14:25:15 +0200 Subject: [PATCH 208/289] ci : use emscripten-core and pin version (#3857) This commit updates the setup emscripten sdk jobs to use emscripten-core instead of mymindstorm and also pins the commit sha for the version instead of using a version tag. --- .github/workflows/build-wasm.yml | 2 +- .../workflows/{examples-wasm.yml => deploy-examples-wasm.yml} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename .github/workflows/{examples-wasm.yml => deploy-examples-wasm.yml} (96%) diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml index 42a9401af3c..d2891eda90f 100644 --- a/.github/workflows/build-wasm.yml +++ b/.github/workflows/build-wasm.yml @@ -40,7 +40,7 @@ jobs: uses: actions/checkout@v6 - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 + uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 - name: Verify run: emcc -v diff --git a/.github/workflows/examples-wasm.yml b/.github/workflows/deploy-examples-wasm.yml similarity index 96% rename from .github/workflows/examples-wasm.yml rename to .github/workflows/deploy-examples-wasm.yml index 927438cdad8..e7fdae77854 100644 --- a/.github/workflows/examples-wasm.yml +++ b/.github/workflows/deploy-examples-wasm.yml @@ -28,7 +28,7 @@ jobs: uses: actions/configure-pages@v5 - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 + uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 - name: Build WASM Examples # Enable for real build later in whisper.cpp From 99613cb720b65036237d44b52f753b51f75c2797 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 16:27:58 +0200 Subject: [PATCH 209/289] ci: build-windows action slimming (#3858) * ci : remove base-devel and git from msys2 job This commit removes the above packages as they might not be required and could help reduce the github cache size. * ci : try reducing the installs to only the compilers --- .github/workflows/build-windows.yml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 9fd910ac0ec..156a57f74b6 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -38,8 +38,8 @@ jobs: fail-fast: false matrix: include: - - { sys: UCRT64, env: ucrt-x86_64, build: Release } - - { sys: CLANG64, env: clang-x86_64, build: Release } + - { sys: UCRT64, env: ucrt-x86_64, compiler: gcc, build: Release } + - { sys: CLANG64, env: clang-x86_64, compiler: clang, build: Release } steps: - name: Clone @@ -51,9 +51,7 @@ jobs: update: true msystem: ${{matrix.sys}} install: >- - base-devel - git - mingw-w64-${{matrix.env}}-toolchain + mingw-w64-${{matrix.env}}-${{matrix.compiler}} mingw-w64-${{matrix.env}}-cmake mingw-w64-${{matrix.env}}-SDL2 mingw-w64-${{matrix.env}}-openblas From 574fc0da69bcf2da3262e40d1b4009341df3d53f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Sat, 6 Jun 2026 05:40:58 +0200 Subject: [PATCH 210/289] ci : add ccache to quantize, vad, and wasm jobs (#3860) * ci : add ccache to build-quantize * ci : add ccache to build-vad * ci : add ccache to build-wasm [no ci] --- .github/workflows/build-quantize.yml | 9 ++++++++- .github/workflows/build-vad.yml | 9 ++++++++- .github/workflows/build-wasm.yml | 18 ++++++++++++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml index 8036a3a3450..69ab2c34638 100644 --- a/.github/workflows/build-quantize.yml +++ b/.github/workflows/build-quantize.yml @@ -31,11 +31,18 @@ jobs: - name: Clone uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: quantize-ubuntu-22 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + - name: Test quantize env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | ./models/download-ggml-model.sh tiny.en - cmake -B build + cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build build --config Release ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml index 71e910a3fcb..3c5ebec2026 100644 --- a/.github/workflows/build-vad.yml +++ b/.github/workflows/build-vad.yml @@ -31,10 +31,17 @@ jobs: - name: Checkout uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: vad-ubuntu-latest + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + - name: Build shell: bash run: | - cmake -B build + cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build build --config Release - name: Test diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml index d2891eda90f..45c77c0be4c 100644 --- a/.github/workflows/build-wasm.yml +++ b/.github/workflows/build-wasm.yml @@ -45,7 +45,21 @@ jobs: - name: Verify run: emcc -v + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: wasm-ubuntu-22 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + - name: Build + env: + CCACHE_SLOPPINESS: time_macros,include_file_mtime,include_file_ctime + CCACHE_COMPILERCHECK: content run: | - emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make + emcmake cmake -B build -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + "-DCMAKE_C_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" \ + "-DCMAKE_CXX_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" + cmake --build build -j $(nproc) From a8ec021f2750a473ff4a8f3883bc9fdf5feafa84 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Sat, 6 Jun 2026 18:34:40 +0200 Subject: [PATCH 211/289] ci : add HF_TOKEN to docker.yml workflow [no ci] (#3861) This commit adds the HF_TOKEN secret to the docker workflows to avoid HF rate limiting which currently sometimes causes the jobs to fail. Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/27053852601/job/79854251771 --- .devops/main-cuda.Dockerfile | 2 +- .devops/main-intel.Dockerfile | 3 ++- .devops/main-musa.Dockerfile | 2 +- .devops/main-vulkan.Dockerfile | 2 +- .devops/main.Dockerfile | 2 +- .github/workflows/docker.yml | 2 ++ 6 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index c2bf0fbd1c6..7a21fc4e3db 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -25,7 +25,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH COPY .. . # Enable cuBLAS -RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index 86b901c1538..a0c04ad34ad 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -10,7 +10,8 @@ RUN apt-get update && \ COPY .. . # Enable SYCL ARG GGML_SYCL_F16=OFF -RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN \ + if [ "${GGML_SYCL_F16}" = "ON" ]; then \ echo "GGML_SYCL_F16 is set" \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ fi && \ diff --git a/.devops/main-musa.Dockerfile b/.devops/main-musa.Dockerfile index 026791e3f89..c68367830f1 100644 --- a/.devops/main-musa.Dockerfile +++ b/.devops/main-musa.Dockerfile @@ -16,7 +16,7 @@ RUN apt-get update && \ COPY .. . # Enable muBLAS -RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_MUSA=1" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile index 077af4f1001..16ee19dc689 100644 --- a/.devops/main-vulkan.Dockerfile +++ b/.devops/main-vulkan.Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . -RUN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" FROM ubuntu:24.04 AS runtime WORKDIR /app diff --git a/.devops/main.Dockerfile b/.devops/main.Dockerfile index e1eb9b33700..d0e809f4e13 100644 --- a/.devops/main.Dockerfile +++ b/.devops/main.Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . -RUN make base.en +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en FROM ubuntu:22.04 AS runtime WORKDIR /app diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index e7ca8595ddd..b4c455b92e9 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -69,3 +69,5 @@ jobs: platforms: ${{ matrix.config.platform }} tags: ${{ steps.tags.outputs.tags }} file: ${{ matrix.config.dockerfile }} + secrets: | + HF_TOKEN=${{ secrets.HF_TOKEN }} From e1da83d7736f4a170a4c8057c205df35c39fe230 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Mon, 8 Jun 2026 07:27:12 +0200 Subject: [PATCH 212/289] ci : add ccache to build-sycl [no ci] (#3859) --- .github/workflows/build-sycl.yml | 40 +++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml index 57aa7cc4d95..c76954e49cf 100644 --- a/.github/workflows/build-sycl.yml +++ b/.github/workflows/build-sycl.yml @@ -61,24 +61,33 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git + sudo apt install intel-oneapi-compiler-dpcpp-cpp - name: install oneAPI MKL library shell: bash run: | - sudo apt install intel-oneapi-mkl-devel git + sudo apt install intel-oneapi-mkl-devel - - name: Clone - id: checkout - uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sycl-${{ matrix.arch }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Build id: cmake_build + env: + CCACHE_SLOPPINESS: time_macros + CCACHE_NODIRECT: 1 run: | source /opt/intel/oneapi/setvars.sh + export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. cmake --build . --config Release -j $(nproc) ubuntu-22-cmake-sycl-fp16: @@ -111,22 +120,31 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git + sudo apt install intel-oneapi-compiler-dpcpp-cpp - name: install oneAPI MKL library shell: bash run: | sudo apt install intel-oneapi-mkl-devel - - name: Clone - id: checkout - uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sycl-fp16-${{ matrix.arch }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Build id: cmake_build + env: + CCACHE_SLOPPINESS: time_macros + CCACHE_NODIRECT: 1 run: | source /opt/intel/oneapi/setvars.sh + export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" mkdir build cd build - cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. cmake --build . --config Release -j $(nproc) From c50e951afdf0b1bd4d63adddbd48dc90ff92893c Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Fri, 29 May 2026 10:15:17 +0200 Subject: [PATCH 213/289] model : support for DeepseekV32ForCausalLM with generic DeepSeek Sparse Attention (DSA) implementation (llama/23346) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * llama : support DeepSeek V3.2 model family (with DSA lightning indexer) * convert : handle DeepseekV32ForCausalLM architecture * ggml : support for f16 GGML_OP_FILL * memory : separate hparams argument in llama_kv_cache constructor * memory : add llama_kv_cache_dsa memory (KV cache + lightning indexer cache) * llama : support for LLM_ARCH_DEEPSEEK32 * model : llama_model_deepseek32 implementation * model : merge two scale operations into one in DSA lightning indexer implementation * chore : remove unused code * model : support NVFP4 in DeepSeek V3.2 Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * memory : refactoring TODO Co-authored-by: ggerganov <ggerganov@users.noreply.github.com> --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: ggerganov <ggerganov@users.noreply.github.com> --- ggml/src/ggml-cpu/ops.cpp | 36 +++++++++++++++++++++++++++++++++++- ggml/src/ggml.c | 2 +- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7485ba4fc86..dc73696ad9f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2235,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg } } +static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0)); + + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const auto [ir0, ir1] = get_thread_range(params, dst); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne2*ne1); + const int64_t i02 = (ir - i03*ne2*ne1)/ne1; + const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + + ggml_vec_set_f16(ne0, dst_ptr, c); + } +} + void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) { - ggml_compute_forward_fill_f32(params, dst); + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fill_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_fill_f16(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type)); + } + } } // ggml_compute_tri diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 476c3079795..8815c67d8bc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5223,7 +5223,7 @@ static struct ggml_tensor * ggml_fill_impl( struct ggml_tensor * a, float c, bool inplace) { - GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); From f7aad4ed7e6818cebdbe87fd78a28f02f6cefedb Mon Sep 17 00:00:00 2001 From: Oliver Simons <osimons@nvidia.com> Date: Fri, 29 May 2026 12:28:18 +0200 Subject: [PATCH 214/289] CUDA: Check PTX version on host side to guard PDL dispatch (llama/23530) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: Check PTX version on host side to guard PDL dispatch Checking on `__CUDA_ARCH_LIST__` alone is insufficient for JIT, as this variable doesn't differentiate between compiling for say sm_90, sm_90a or sm_90f (so forward-jittable PTX vs. arch/family-specific PTX). Thus, one can have a bug when compiling with `DCMAKE_CUDA_ARCHITECTURES="89;90a"`, where current code would wrongly dispatch to PDL on sm_90/sm_120 in forward-JIT mode. This PR fixes this issue by checking `cudaFuncAttributes::ptxVersion` of the incoming kernel at runtime. A check on ptxVersion alone is sufficient, as device-codes will always be >= ptxVersion (and any violation of this would be a severe bug in CUDA/nvcc), see: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#gpu-code-code-code * Implement MurmurHash3 mixer for better hash distribution Magic constants were taken from boost: https://github.com/boostorg/container_hash/blob/2698b43803c012601e6bb1a6116e83767b97986c/include/boost/container_hash/detail/hash_mix.hpp#L19-L65 * Update ggml/src/ggml-cuda/common.cuh Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Address review comments, make seed non-zero * Apply code-formatting * Replace std::size_t -> size_t for consistency --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --- ggml/src/ggml-cuda/common.cuh | 60 +++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 50d7763dcdd..560fab0b17b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -7,6 +7,7 @@ #include <cstdint> #include <cstdlib> #include <memory> +#include <mutex> #if defined(GGML_USE_HIP) #define GGML_COMMON_DECL_HIP @@ -1552,6 +1553,62 @@ struct ggml_cuda_pdl_config { ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; }; + +static bool ggml_cuda_kernel_can_use_pdl(const void * kernel) { + const int device = ggml_cuda_get_device(); + + struct cache_key { + int device; + const void * kernel; + + bool operator==(const cache_key & other) const { return device == other.device && kernel == other.kernel; } + }; + + struct cache_key_hash { + // MurmurHash3 mixing function for better hash distribution (vs. just std::hash which in some implementations simply returns the identity) + static size_t hash_mix(size_t x) { + std::uint64_t y = x; + const std::uint64_t m = 0xe9846af9b1a615d; + + y ^= y >> 32; + y *= m; + y ^= y >> 32; + y *= m; + y ^= y >> 28; + + return static_cast<size_t>(y); + } + + size_t operator()(const cache_key & key) const { + // Use a nonzero seed to avoid mapping all-zero keys to zero + size_t h = 42; + h = hash_mix(h + key.device); + h = hash_mix(h + reinterpret_cast<size_t>(key.kernel)); + return h; + } + }; + + static std::mutex cache_mutex; + static std::unordered_map<cache_key, bool, cache_key_hash> cache; + + const cache_key key = { device, kernel }; + std::lock_guard<std::mutex> lock(cache_mutex); + const auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + cudaFuncAttributes attr = {}; + CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel)); + + // PDL device-side primitives are emitted only for PTX versions >= 90. + // We have to guard on a loaded kernel's PTX version so a kernel forward-JIT'ed + // from pre-Hopper PTX to a Hopper-or-newer GPU does not opt into PDL. + const bool can_use_pdl = attr.ptxVersion >= 90; + cache.emplace(key, can_use_pdl); + return can_use_pdl; +} + #endif //defined(GGML_CUDA_USE_PDL) @@ -1564,8 +1621,7 @@ static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_ke return env == nullptr || std::atoi(env) != 0; }(); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) { + if (env_pdl_enabled && ggml_cuda_kernel_can_use_pdl(reinterpret_cast<const void *>(kernel))) { auto pdl_cfg = ggml_cuda_pdl_config(launch_params); CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... )); From acd91d2c3891ac8f2538152882552657691ed6af Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Fri, 29 May 2026 14:14:11 -0700 Subject: [PATCH 215/289] ggml-webgpu: add q4_0/q8_0 SET_ROWS (llama/23760) * Add q8_0 and q4_0 set_rows * Add fast(er) quantization set_rows path * formatting/naming * a little more naming * Remove unused constant * Don't override other override * Avoid bitcast * Narrow relaxation --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 90 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 +- .../ggml-webgpu/wgsl-shaders/set_rows.wgsl | 5 +- .../wgsl-shaders/set_rows_quant.wgsl | 224 ++++++++++++++++++ 4 files changed, 289 insertions(+), 41 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 60e98a60741..f4c5eca0df5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -84,16 +84,16 @@ struct ggml_webgpu_shader_lib_context { ggml_tensor * src5; ggml_tensor * dst; - uint32_t max_wg_size; - size_t wg_mem_limit_bytes = 0; - bool supports_subgroups = false; - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m = 0; - uint32_t sg_mat_n = 0; - uint32_t sg_mat_k = 0; - uint32_t min_subgroup_size = 0; - uint32_t max_subgroup_size = 0; - bool supports_dot_product = false; + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool supports_subgroups = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; + uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; std::string vendor; }; @@ -166,9 +166,11 @@ struct ggml_webgpu_set_rows_pipeline_key { int dst_type; int vec4; int i64_idx; + int pair_blocks; bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { - return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx && + pair_blocks == other.pair_blocks; } }; @@ -178,6 +180,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.vec4); ggml_webgpu_hash_combine(seed, key.i64_idx); + ggml_webgpu_hash_combine(seed, key.pair_blocks); return seed; } }; @@ -185,6 +188,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { struct ggml_webgpu_set_rows_shader_decisions { bool vec4; bool i64_idx; + bool pair_blocks; uint32_t wg_size; }; @@ -772,31 +776,30 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && - context.src2->ne[0] % kv_vec_head_align == 0; + const uint32_t kv_vec_head_align = + K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = + context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0; // Compile with enough invocations to cover the largest reported subgroup. - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && - kv_vec_head_dims_aligned && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_subgroup_matrix = - context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && - context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; + const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && + context.src0->ne[0] % context.sg_mat_k == 0 && + context.src2->ne[0] % context.sg_mat_n == 0; const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && tile_can_dispatch_all_q_rows && !use_vec; - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { return decisions; @@ -1131,9 +1134,9 @@ class ggml_webgpu_shader_lib { ggml_webgpu_flash_attn_blk_pipeline_key_hash> flash_attn_blk_pipelines; std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash> - mul_mat_vec_pipelines; // fast mat-vec (n==1) + mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash> - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash> quantize_q8_pipelines; std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed @@ -1264,10 +1267,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = {}; - key.dst_type = context.dst->type; - key.vec4 = context.src0->ne[0] % 4 == 0; - key.i64_idx = context.src1->type == GGML_TYPE_I64; + const bool quantized = ggml_is_quantized(context.dst->type); + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = + (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; + key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0); auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -1286,6 +1292,14 @@ class ggml_webgpu_shader_lib { defines.push_back("DST_F16"); variant += "_dstf16"; break; + case GGML_TYPE_Q8_0: + defines.push_back("DST_Q8_0"); + variant += "_dstq8_0"; + break; + case GGML_TYPE_Q4_0: + defines.push_back("DST_Q4_0"); + variant += "_dstq4_0"; + break; default: GGML_ABORT("Unsupported dst type for set_rows shader"); } @@ -1298,13 +1312,19 @@ class ggml_webgpu_shader_lib { defines.push_back("I64_IDX"); variant += "_i64idx"; } + if (key.pair_blocks) { + defines.push_back("PAIR_BLOCKS"); + variant += "_pair_blocks"; + } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_set_rows, defines); - auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); + const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows; + auto processed = preprocessor.preprocess(shader_source, defines); + auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); decisions->vec4 = key.vec4; decisions->i64_idx = key.i64_idx; + decisions->pair_blocks = key.pair_blocks; decisions->wg_size = context.max_wg_size; set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); set_rows_pipelines[key].context = decisions; @@ -1660,7 +1680,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1819,7 +1839,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; - key.use_mmvq = + key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1846886db4e..1a99f1cb52f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1331,7 +1331,11 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct } uint32_t threads; - if (decisions->vec4) { + if (ggml_is_quantized(dst->type)) { + const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type); + threads = + (src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row); + } else if (decisions->vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; @@ -4046,8 +4050,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); break; case GGML_OP_SET_ROWS: - supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); + supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 || + op->type == GGML_TYPE_Q4_0) && + src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); break; case GGML_OP_GET_ROWS: if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl index 99e9192c71a..09f2f0eddb3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -71,7 +71,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { return; } - // getting the row from gid let elems_per_row = params.ne0 / VEC_SIZE; var i = gid.x / elems_per_row; @@ -104,6 +103,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - let col_idx = (gid.x % elems_per_row); - dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]); + let col_idx = gid.x % elems_per_row; + dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl new file mode 100644 index 00000000000..876e65b6ae1 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl @@ -0,0 +1,224 @@ +#ifdef DST_Q8_0 +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 34u +#define QS_WORDS 8u +#elif defined(DST_Q4_0) +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 18u +#define QS_WORDS 4u +#endif + +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> idx: array<u32>; + +@group(0) @binding(2) +#ifdef PAIR_BLOCKS +var<storage, read_write> dst: array<u32>; +#else +var<storage, read_write> dst: array<atomic<u32>>; +#endif + +#ifdef I64_IDX +@group(0) @binding(3) +var<storage, read_write> error: atomic<u32>; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in blocks + + // Strides (in elements / blocks) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var<uniform> params: Params; + +// if the quantization type is unaligned and there are an odd number of blocks per row, we need to store atomically +#ifndef PAIR_BLOCKS +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + loop { + let old = atomicLoad(&dst[word_idx]); + let merged = (old & ~mask) | (bits & mask); + let result = atomicCompareExchangeWeak(&dst[word_idx], old, merged); + if (result.exchanged) { + return; + } + } +} +#else +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + let old = dst[word_idx]; + dst[word_idx] = (old & ~mask) | (bits & mask); +} +#endif + +fn store_u16(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 2u) * 8u; + let mask = 0xFFFFu << shift; + merge_store_dst_word(word_idx, mask, (value & 0xFFFFu) << shift); +} + +fn store_u32(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 3u) * 8u; + + if (shift == 0u) { +#ifdef PAIR_BLOCKS + dst[word_idx] = value; +#else + atomicStore(&dst[word_idx], value); +#endif + return; + } + + let lo_mask = 0xFFFFFFFFu << shift; + let hi_mask = (1u << shift) - 1u; + merge_store_dst_word(word_idx, lo_mask, value << shift); + merge_store_dst_word(word_idx + 1u, hi_mask, value >> (32u - shift)); +} + +fn quantize_block_params(src_block: u32) -> vec2<f32> { +#ifdef DST_Q8_0 + var amax = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + amax = max(amax, abs(src[src_block + j])); + } + + let d = amax / 127.0; + let id = select(0.0, 1.0 / d, d > 0.0); + return vec2(d, id); +#elif defined(DST_Q4_0) + var amax = 0.0; + var max_val = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + let v = src[src_block + j]; + let av = abs(v); + if (amax < av) { + amax = av; + max_val = v; + } + } + + let d = max_val / -8.0; + let id = select(0.0, 1.0 / d, d != 0.0); + return vec2(d, id); +#endif +} + +fn quantize_block_word(src_block: u32, j: u32, id: f32) -> u32 { +#ifdef DST_Q8_0 + let base = src_block + j * 4u; + return (u32(i32(round(src[base + 0u] * id)) & 0xFF) << 0u) | + (u32(i32(round(src[base + 1u] * id)) & 0xFF) << 8u) | + (u32(i32(round(src[base + 2u] * id)) & 0xFF) << 16u) | + (u32(i32(round(src[base + 3u] * id)) & 0xFF) << 24u); +#elif defined(DST_Q4_0) + var packed_q = 0u; + for (var k: u32 = 0u; k < 4u; k++) { + let x0 = src[src_block + j * 4u + k] * id; + let x1 = src[src_block + 16u + j * 4u + k] * id; + let q0 = u32(clamp(i32(x0 + 8.5), 0, 15)); + let q1 = u32(clamp(i32(x1 + 8.5), 0, 15)); + packed_q |= (q0 & 0xFu) << (8u * k); + packed_q |= (q1 & 0xFu) << (8u * k + 4u); + } + return packed_q; +#endif +} + +fn quantize_block(src_block: u32, dst_word_idx: u32, block_byte_offset: u32) { + let params = quantize_block_params(src_block); + let d = params.x; + let id = params.y; + let packed_d = pack2x16float(vec2(d, 0.0)) & 0xFFFFu; + store_u16(dst_word_idx, block_byte_offset, 0u, packed_d); + + for (var j: u32 = 0u; j < QS_WORDS; j++) { + store_u32(dst_word_idx, block_byte_offset, 2u + j * 4u, quantize_block_word(src_block, j, id)); + } +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + let blocks_per_row = params.ne0 / BLOCK_SIZE; +#ifdef PAIR_BLOCKS + let blocks_per_invocation = 2u; +#else + let blocks_per_invocation = 1u; +#endif + let invocations_per_row = blocks_per_row / blocks_per_invocation; + let total_invocations = params.ne3 * params.ne2 * params.n_rows * invocations_per_row; + if (gid.x >= total_invocations) { + return; + } + + var i = gid.x / invocations_per_row; + let block_in_row = (gid.x % invocations_per_row) * blocks_per_invocation; + + let i_src3 = i / (params.ne2 * params.n_rows); + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + +#ifdef I64_IDX + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2u; + let idx_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1u]; + + if (idx_low_val != 0u) { + atomicStore(&error, 1u); + return; + } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif + + let dst_row_blocks = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + let src_block = src_row + block_in_row * BLOCK_SIZE; + let dst_block_byte = (dst_row_blocks + block_in_row) * BLOCK_BYTES; + + let dst_word_idx = dst_block_byte / 4u; +#ifdef PAIR_BLOCKS + quantize_block(src_block, dst_word_idx, 0u); + quantize_block(src_block + BLOCK_SIZE, dst_word_idx, BLOCK_BYTES); +#else + quantize_block(src_block, dst_word_idx, dst_block_byte & 3u); +#endif +} From 9147a9676b9945920088e90ee703d2ff462ded8b Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Fri, 29 May 2026 14:16:05 -0700 Subject: [PATCH 216/289] ggml-webgpu: Check earlier for WebGPU required features (llama/23879) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1a99f1cb52f..d577b5afa3c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3724,7 +3724,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { wgpu::RequestAdapterOptions options = {}; #ifndef __EMSCRIPTEN__ @@ -3762,10 +3762,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); ctx->webgpu_global_ctx->vendor = info.vendor; - wgpu::SupportedFeatures features; - ctx->webgpu_global_ctx->adapter.GetFeatures(&features); - // we require f16 support - GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); // for dot4I8packed @@ -3877,7 +3873,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { "device_desc: %s\n", info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, std::string(info.device).c_str(), std::string(info.description).c_str()); - return true; } static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { @@ -4507,7 +4502,12 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { UINT64_MAX); } - if (adapter != nullptr) { + // WebGPU backend requires f16 support and, on native, implicit device synchronization. + if (adapter != nullptr && adapter.HasFeature(wgpu::FeatureName::ShaderF16) +#ifndef __EMSCRIPTEN__ + && adapter.HasFeature(wgpu::FeatureName::ImplicitDeviceSynchronization) +#endif + ) { ctx->device_count = 1; } @@ -4515,8 +4515,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } ggml_backend_t ggml_backend_webgpu_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - + ggml_backend_reg_t reg = ggml_backend_webgpu_reg(); + if (ggml_backend_reg_dev_count(reg) == 0) { + return nullptr; + } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0); return ggml_backend_webgpu_backend_init(dev, nullptr); } From 4317ddbe2b0fa7f436593658c8252469e598b36a Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Sat, 30 May 2026 10:39:31 +0200 Subject: [PATCH 217/289] vulkan: add Flash Attention support for BFloat16 KV cache (llama/23420) * vulkan: add flash attention bf16 kv support * vulkan: bf16 FA coopmat1 support * vulkan: bf16 FA coopmat2 support * fix FA bf16 f32 fallback * fix FA bf16 coopmat1 shader * fix FA bf16 coopmat2 shader * code cleanup * cleanup comment change * address feedback * add O_TYPE for cm2 FA * use O_TYPE for gqaStore function * reduce BFLOAT16 ifdefs --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 148 +++++++++++++----- .../vulkan-shaders/flash_attn_base.glsl | 12 +- .../vulkan-shaders/flash_attn_cm1.comp | 98 +++++++----- .../vulkan-shaders/flash_attn_cm2.comp | 36 +++-- .../vulkan-shaders/flash_attn_dequant.glsl | 8 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 22 +++ 6 files changed, 235 insertions(+), 89 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c9f906d7930..2a30fb95c61 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -691,6 +691,7 @@ struct vk_device_struct { uint32_t coopmat_int_k; bool coopmat2; + bool coopmat2_bf16_support {}; bool coopmat2_decode_vector; bool pipeline_executable_properties_support {}; @@ -3139,7 +3140,7 @@ struct vk_fa_tuning_params { }; static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16); static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { @@ -3279,6 +3280,13 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) { + path = FA_COOPMAT1; + } + if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) { + path = FA_SCALAR; + } + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 path = FA_SCALAR; @@ -3288,7 +3296,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type); if (!shape_ok || !shmem_ok) { path = FA_SCALAR; @@ -3334,8 +3342,8 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) { const auto fa_block_bytes = [](ggml_type t) -> uint32_t { - // decodeBufF32 uses a block of vec4s for a better memory access pattern. - return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + if (t == GGML_TYPE_F32) return 16u; + return (uint32_t) ggml_type_size(t); }; return { /* 0 WorkGroupSize */ state.workgroup_size, @@ -3849,10 +3857,16 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); const void * spv_data = nullptr; size_t spv_size = 0; - if (use_mmq) { + const char *name = nullptr; + if (bf16_kv) { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16"; + } else if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->fp16) { if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } @@ -3862,6 +3876,7 @@ static void ggml_vk_load_shaders(vk_device& device) { spv_size = flash_attn_f32_f16_fp32_int8_len; } #endif + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } else { if (device->fp16) { if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } @@ -3870,8 +3885,8 @@ static void ggml_vk_load_shaders(vk_device& device) { spv_data = flash_attn_f32_f16_fp32_data; spv_size = flash_attn_f32_f16_fp32_len; } + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } - const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3889,11 +3904,25 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; + const void * spv_data; size_t spv_size; - if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } - else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } - const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + const char *name; + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm1_data; + spv_size = flash_attn_f32_f16_bf16_cm1_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1"; +#else + continue; +#endif + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + } ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3911,10 +3940,20 @@ static void ggml_vk_load_shaders(vk_device& device) { const bool aligned = fa.first.aligned; const bool f32acc = fa.first.f32acc; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const void * spv_data; size_t spv_size; const char * name; - if (aligned) { + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat2_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm2_data; + spv_size = flash_attn_f32_f16_bf16_cm2_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2"; +#else + continue; +#endif + } else if (aligned) { if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } } else { @@ -5784,46 +5823,72 @@ static vk_device ggml_vk_get_device(size_t idx) { found_fp16_256 = false, found_fp32_128 = false, found_fp32_256 = false; + bool found_bf16_128 = false, + found_bf16_256 = false; // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 // with 32x16x16 and 256 with 32x32x16. for (auto &prop : flexible_dimensions) { if (prop.saturatingAccumulation == VK_FALSE && - prop.scope == VK_SCOPE_WORKGROUP_KHR && - prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - - if (prop.workgroupInvocations == 128 && - prop.MGranularity <= 32 && - prop.NGranularity <= 16 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_128 = true; + prop.scope == VK_SCOPE_WORKGROUP_KHR) { + + if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_128 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } } } - if (prop.workgroupInvocations == 256 && - prop.MGranularity <= 32 && - prop.NGranularity <= 32 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_256 = true; + +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + found_bf16_128 = true; } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_256 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + found_bf16_256 = true; } } +#endif } } if (found_fp16_128 && found_fp16_256 && found_fp32_128 && found_fp32_256 && coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { device->coopmat2 = true; + device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256; device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; } } @@ -9448,7 +9513,8 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; - const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + // BF16 uses the fp32 shader (FLOAT_TYPE=float) + const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float); const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); @@ -9489,7 +9555,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) { // Needs to be kept up to date on shader changes const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; @@ -9519,8 +9585,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t vsh_stride = MatBc / 4 * row_split; const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; + // BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes) + const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4; const uint32_t osh_stride = params.row_split * MatBr / 4; - const uint32_t pvsh = MatBc * osh_stride * f16vec4; + const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size; const uint32_t slope = Br * acctype; @@ -9589,7 +9657,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). @@ -16400,6 +16468,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (t) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_0: @@ -16415,6 +16484,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } + if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) { + return false; + } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll return false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9a7957da97b..66dcf610219 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -97,8 +97,17 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define FA_TYPE_Q5_0 6u #define FA_TYPE_Q5_1 7u #define FA_TYPE_Q8_0 8u +#define FA_TYPE_BF16 30u #define FA_TYPE_Q1_0 41u +#if defined(BFLOAT16) +#define O_TYPE float +#define O_TYPEV4 vec4 +#else +#define O_TYPE FLOAT_TYPE +#define O_TYPEV4 FLOAT_TYPEV4 +#endif + // Number of matrix elements per buffer block, derived from the K/V type spec // constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 // and bypasses the dequant path entirely. Quants follow their ggml block sizes. @@ -111,6 +120,7 @@ uint fa_block_elems(uint ty) { case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_BF16: return 1u; case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere default: return 1u; } @@ -248,7 +258,7 @@ const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +void gqaStore(const in uint32_t r, const in uint32_t c, const in O_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV / 4 + c; data_ov4[o_offset + offset] = D_TYPEV4(elems); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index bffcc095be3..23ae3833e52 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -6,6 +6,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -14,7 +18,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "flash_attn_dequant.glsl" +#endif // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -27,32 +33,32 @@ const uint32_t cols_per_thread = Bc / cols_per_iter; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; -layout (binding = 1) readonly buffer K {float16_t data_k[];}; -layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; -layout (binding = 2) readonly buffer V {float16_t data_v[];}; -layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 1) readonly buffer K {FLOAT_TYPE data_k[];}; +layout (binding = 1) readonly buffer KV4 {FLOAT_TYPEV4 data_kv4[];}; +layout (binding = 2) readonly buffer V {FLOAT_TYPE data_v[];}; +layout (binding = 2) readonly buffer VV4 {FLOAT_TYPEV4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; shared float tmpsh[row_split]; -const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 Qf[Br * qstride]; +const uint32_t qstride = HSK_pad / 4 + 2; +shared FLOAT_TYPEV4 Qf[Br * qstride]; const uint psh_stride = Br / 4 + 2; -shared f16vec4 Psh[Bc * psh_stride]; +shared FLOAT_TYPEV4 Psh[Bc * psh_stride]; // Avoid padding for hsk==256 to make it fit in 48KB shmem. const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; shared ACC_TYPEV4 sfsh[Bc * sfshstride]; const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; -const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups const uint vsh_stride = v_cols; -shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; +shared FLOAT_TYPEV4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; const uint32_t osh_stride = row_split * MatBr / 4; -shared f16vec4 pvsh[MatBc * osh_stride]; +shared O_TYPEV4 pvsh[MatBc * osh_stride]; shared ACC_TYPE slope[Br]; @@ -76,7 +82,7 @@ void main() { if ((HSK % 16) != 0) { [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { if (i + tid < Br * qstride) { - Qf[i + tid] = f16vec4(0); + Qf[i + tid] = FLOAT_TYPEV4(0); } } barrier(); @@ -89,15 +95,15 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - f16vec4 Of[rows_per_thread][d_per_thread]; + O_TYPEV4 Of[rows_per_thread][d_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { - Of[r][d] = f16vec4(0.0); + Of[r][d] = O_TYPEV4(0.0); } } @@ -222,15 +228,18 @@ void main() { uint32_t d = (idx + tid) % (HSK_pad / 4); uint32_t c = (idx + tid) / (HSK_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); } } @@ -244,16 +253,16 @@ void main() { // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; - coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem - // If not, f16 K is loaded directly from global memory if aligned, otherwise + // If not, K is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If K is not type f16, then it is always staged for dequantization. + // If K is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always need to dequant into kvsh; for f16 we can load + // For quants we always need to dequant into kvsh; for f16/bf16 we can load // directly from global memory when alignment / bounds allow it. const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; if (stage_k) { @@ -262,15 +271,18 @@ void main() { uint32_t col_vec = (idx + tid) % (MatBr / 4); uint32_t row = (idx + tid) / (MatBr / 4); if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); } } @@ -357,7 +369,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local]; + Of[r][d_local] = O_TYPE(eMf[r]) * Of[r][d_local]; } } @@ -368,10 +380,10 @@ void main() { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { const uint row = tile_row(r); if (KV_bounds_check && j * Bc + col >= KV) { - Psh[col * psh_stride + row / 4] = f16vec4(0.0f); + Psh[col * psh_stride + row / 4] = FLOAT_TYPEV4(0.0f); } else { const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); - const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + const FLOAT_TYPEV4 Pf = FLOAT_TYPEV4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { Lf[r + vec_idx] += Pf[vec_idx]; } @@ -385,15 +397,18 @@ void main() { uint32_t d = (idx + tid) % (HSV_pad / 4); uint32_t c = (idx + tid) / (HSV_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { - f16vec4 V_Tf = f16vec4(0); + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; uint ib = coord / BLOCK_SIZE_V; uint iqs = (coord % BLOCK_SIZE_V); V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } else +#endif + { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); } } @@ -409,7 +424,7 @@ void main() { [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; - coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); + coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); // Preload V tiles for [Bc, 16 * num subgroups] const uint v_rows = Bc; @@ -417,11 +432,11 @@ void main() { const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. - // If not, f16 V is loaded directly from global memory if aligned, otherwise + // If not, V is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If V is not type f16, then it is always staged for dequantization. + // If V is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always preload via kvsh. For f16 we only preload when + // For quants we always preload via kvsh. For f16/bf16 we only preload when // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). const bool stage_v = USE_DECODE_V || KV_bounds_check; if (stage_v) { @@ -438,13 +453,16 @@ void main() { const uint iqs = coord % BLOCK_SIZE_V; if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { + } else +#endif + { kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; } } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + kvsh[row * vsh_stride + col] = FLOAT_TYPEV4(0.0f); } } } @@ -459,7 +477,7 @@ void main() { if (SHMEM_STAGING == 0) { if (!USE_DECODE_V && !KV_bounds_check) { - // F16 values can be loaded directly from global memory + // F16/BF16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); @@ -573,7 +591,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; - Of[r][d_local] *= float16_t(ms); + Of[r][d_local] *= O_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -591,7 +609,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] *= float16_t(Lfrcp[r]); + Of[r][d_local] *= O_TYPE(Lfrcp[r]); #if defined(FLOAT_TYPE_MAX) Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 6d45b4931df..b9c03fe499d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -8,6 +8,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable @@ -21,7 +25,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "dequant_funcs_cm2.glsl" +#endif // buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { @@ -31,6 +37,7 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; +#if !defined(BFLOAT16) float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); @@ -91,6 +98,7 @@ f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], co #define FADECODEK , faDecodeK #define FADECODEV , faDecodeV #endif +#endif layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -195,15 +203,15 @@ void main() { tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q; - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); - Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - Qf16 *= float16_t(p.scale); + Q *= Q_TYPE(p.scale); + Qf16 = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); + coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; @@ -291,16 +299,20 @@ void main() { coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. +#if defined(BFLOAT16) + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); +#else const bool k_use_decode = (bs_k > 1u); if (k_use_decode) { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK); } else { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); } +#endif S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -351,22 +363,26 @@ void main() { coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); } - coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); // compute rowsum by multiplying by matrix of all ones. - coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; +#if defined(BFLOAT16) + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); +#else const bool v_use_decode = (bs_v > 1u); if (v_use_decode) { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV); } else { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); } +#endif L = eM*L + rowsum; @@ -378,7 +394,7 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag); + O *= coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag); O = coopMatMulAdd(P_A, V, O); } @@ -427,7 +443,7 @@ void main() { if (sink > Mr[i]) { ms = exp(Mr[i] - sink); - O[i] *= float16_t(ms); + O[i] *= O_TYPE(ms); } else { vs = exp(sink - Mr[i]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl index 02106f33cbe..8704479d960 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -28,6 +28,9 @@ layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; +layout (binding = 1) readonly buffer K_PACKED_BF16 { u16vec4 data[]; } k_packed_bf16; +layout (binding = 2) readonly buffer V_PACKED_BF16 { u16vec4 data[]; } v_packed_bf16; + // Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 // views, used by the MMQ K-side hot path for fast 4-uint loads. layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; @@ -99,6 +102,9 @@ layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 dat return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ } +#define FA_DEQUANT4_BF16(BUF) \ + return FLOAT_TYPEV4(bf16_to_fp32(uvec4(BUF.data[(a_offset + ib) / 4]))); + FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { switch (FaTypeK) { @@ -108,6 +114,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(k_packed_bf16) } } else { switch (FaTypeV) { @@ -117,6 +124,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(v_packed_bf16) } } return FLOAT_TYPEV4(0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index fa9b938e4f7..de7dbec2c63 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -662,6 +662,28 @@ void process_shaders() { } } + const std::map<std::string, std::string> fa_bf16_dict = { + {"FLOAT_TYPE", "bfloat16_t"}, + {"FLOAT_TYPEV2", "bf16vec2"}, + {"FLOAT_TYPEV4", "bf16vec4"}, + {"ACC_TYPE", "float"}, + {"ACC_TYPEV2", "vec2"}, + {"ACC_TYPEV4", "vec4"}, + {"BFLOAT16", "1"}, + }; + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm1.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), + true, true, false, false); +#endif + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm2.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), + true, false, true, false); +#endif + std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) { From 64b0d6b7fca11e62ef705127ec63fafdf0800a1e Mon Sep 17 00:00:00 2001 From: Jinyang He <hejinyang@loongson.cn> Date: Sat, 30 May 2026 16:53:26 +0800 Subject: [PATCH 218/289] ggml : add some lsx support (llama/23798) * loongarch : optimize LSX fp16 load/store with native intrinsics Use __lsx_vfcvtl_s_h and __lsx_vfcvt_h_s instead of scalar loops in __lsx_f16x4_load and __lsx_f16x4_store. * loongarch : add LSX implementation for q8_0 dot product * loongarch : add LSX implementation for q6_K dot product * loongarch : add LSX implementation for iq4_xs dot product * Improve reduce ops when sun int16 pairs to int32 --- ggml/src/ggml-cpu/arch/loongarch/quants.c | 151 ++++++++++++++++++++++ ggml/src/ggml-cpu/simd-mappings.h | 19 +-- 2 files changed, 154 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index 74e0c086c6d..9c43da6cf89 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -977,6 +977,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc); *s = sumf; + +#elif defined(__loongarch_sx) + + __m128 acc = (__m128)__lsx_vldi(0); + + for (; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d); + const __m128i qx_0 = __lsx_vld((const __m128i *)x[ib].qs, 0); + const __m128i qx_1 = __lsx_vld((const __m128i *)x[ib].qs + 1, 0); + const __m128i qy_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); + const __m128i qy_1 = __lsx_vld((const __m128i *)y[ib].qs + 1, 0); + + const __m128i p16_0 = lsx_maddubs_h(qx_0, qy_0); + const __m128i p16_1 = lsx_maddubs_h(qx_1, qy_1); + + // Sum int16 pairs → int32 + const __m128i s_0 = __lsx_vaddwev_w_h(p16_0, p16_1); + const __m128i s_1 = __lsx_vaddwod_w_h(p16_0, p16_1); + + const __m128 q = __lsx_vffint_s_w(__lsx_vadd_w(s_0, s_1)); + acc = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(d), q, acc); + } + + __m128 res = lsx_hadd_s(acc, acc); + res = lsx_hadd_s(res, res); + sumf = ((v4f32)res)[0]; + + *s = sumf; + #else UNUSED(nb); UNUSED(ib); @@ -1443,6 +1472,99 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = hsum_float_8(acc); +#elif defined(__loongarch_sx) + + const __m128i m32s = __lsx_vreplgr2vr_b(32); + + __m128 acc_0 = (__m128)__lsx_vldi(0); + __m128 acc_1 = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i scale_i8 = __lsx_vld(x[i].scales, 0); + const __m128i scales_lo = __lsx_vsllwil_h_b(scale_i8, 0); + const __m128i scales_hi = __lsx_vsllwil_h_b(__lsx_vbsrl_v(scale_i8, 8), 0); + + __m128i sumi_0 = __lsx_vldi(0); + __m128i sumi_1 = __lsx_vldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + const __m128i q4bitsH_1 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + + const __m128i q4h_0 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3), 4); + const __m128i q4h_1 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3), 4); + const __m128i q4h_2 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3 << 2), 2); + const __m128i q4h_3 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3 << 2), 2); + const __m128i q4h_4 = __lsx_vandi_b(q4bitsH_0, 3 << 4); + const __m128i q4h_5 = __lsx_vandi_b(q4bitsH_1, 3 << 4); + const __m128i q4h_6 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_0, 3 << 6), 2); + const __m128i q4h_7 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_1, 3 << 6), 2); + + const __m128i q4bits1_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits1_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + + const __m128i q4_0 = __lsx_vor_v(__lsx_vandi_b(q4bits1_0, 0xf), q4h_0); + const __m128i q4_1 = __lsx_vor_v(__lsx_vandi_b(q4bits1_1, 0xf), q4h_1); + const __m128i q4_2 = __lsx_vor_v(__lsx_vandi_b(q4bits2_0, 0xf), q4h_2); + const __m128i q4_3 = __lsx_vor_v(__lsx_vandi_b(q4bits2_1, 0xf), q4h_3); + const __m128i q4_4 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_0, 4), q4h_4); + const __m128i q4_5 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_1, 4), q4h_5); + const __m128i q4_6 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_0, 4), q4h_6); + const __m128i q4_7 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_1, 4), q4h_7); + + const __m128i q8_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_2 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_3 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_4 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_5 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_6 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_7 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + + __m128i p16_0 = lsx_maddubs_h(__lsx_vsub_b(q4_0, m32s), q8_0); + __m128i p16_1 = lsx_maddubs_h(__lsx_vsub_b(q4_1, m32s), q8_1); + __m128i p16_2 = lsx_maddubs_h(__lsx_vsub_b(q4_2, m32s), q8_2); + __m128i p16_3 = lsx_maddubs_h(__lsx_vsub_b(q4_3, m32s), q8_3); + __m128i p16_4 = lsx_maddubs_h(__lsx_vsub_b(q4_4, m32s), q8_4); + __m128i p16_5 = lsx_maddubs_h(__lsx_vsub_b(q4_5, m32s), q8_5); + __m128i p16_6 = lsx_maddubs_h(__lsx_vsub_b(q4_6, m32s), q8_6); + __m128i p16_7 = lsx_maddubs_h(__lsx_vsub_b(q4_7, m32s), q8_7); + + const __m128i sc_vec = j == 0 ? scales_lo : scales_hi; + + p16_0 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 0), p16_0); + p16_1 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 1), p16_1); + p16_2 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 2), p16_2); + p16_3 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 3), p16_3); + p16_4 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 4), p16_4); + p16_5 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 5), p16_5); + p16_6 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 6), p16_6); + p16_7 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 7), p16_7); + + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_0, p16_2)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_1, p16_3)); + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_4, p16_6)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_5, p16_7)); + } + + __m128 p_0 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_0)); + __m128 p_1 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_1)); + acc_0 = __lsx_vfadd_s(p_0, acc_0); + acc_1 = __lsx_vfadd_s(p_1, acc_1); + } + + *s = hsum_float_4x4(acc_0, acc_1, (__m128)__lsx_vldi(0), (__m128)__lsx_vldi(0)); + #else UNUSED(x); UNUSED(y); @@ -2149,6 +2271,35 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v *s = hsum_float_8(accum); +#elif defined(__loongarch_sx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + + __m128 accum = (__m128)__lsx_vldi(0); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi = __lsx_vldi(0); + for (int ib = 0; ib < QK_K/32; ++ib) { + const __m128i q4bits = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q8b_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8b_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q4b_0 = __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits, 0xf)); + const __m128i q4b_1 = __lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits, 4)); + const __m128i p16_0 = lsx_maddubs_h(q4b_0, q8b_0); + const __m128i p16_1 = lsx_maddubs_h(q4b_1, q8b_1); + const int16_t ls = (((x[ibl].scales_l[ib/2] >> ((ib & 1) * 4)) & 0xf) | ((sh & 0x3) << 4)) - 32; + sh >>= 2; + sumi = __lsx_vadd_w(lsx_madd_h(p16_0, __lsx_vreplgr2vr_h(ls)), sumi); + sumi = __lsx_vadd_w(lsx_madd_h(p16_1, __lsx_vreplgr2vr_h(ls)), sumi); + } + const float ds = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + accum = __lsx_vfadd_s(__lsx_vfmul_s(__lsx_vreplfr2vr_s(ds), __lsx_vffint_s_w(sumi)), accum); + } + + *s = ((v4f32)lsx_hadd_s(lsx_hadd_s(accum, accum), lsx_hadd_s(accum, accum)))[0]; + #else UNUSED(x); UNUSED(y); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 0deda930985..62e687201ef 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -1125,25 +1125,12 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { #define GGML_F16_EPR 4 static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { - float tmp[4]; - - tmp[0] = GGML_CPU_FP16_TO_FP32(x[0]); - tmp[1] = GGML_CPU_FP16_TO_FP32(x[1]); - tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]); - tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]); - - return (__m128)__lsx_vld(tmp, 0); + return __lsx_vfcvtl_s_h(__lsx_vld((const void *)x, 0)); } static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { - float arr[4]; - - __lsx_vst(y, arr, 0); - - x[0] = GGML_CPU_FP32_TO_FP16(arr[0]); - x[1] = GGML_CPU_FP32_TO_FP16(arr[1]); - x[2] = GGML_CPU_FP32_TO_FP16(arr[2]); - x[3] = GGML_CPU_FP32_TO_FP16(arr[3]); + __m128i a = __lsx_vfcvt_h_s(y, y); + memcpy(x, &a, sizeof(ggml_fp16_t) * 4); } #define GGML_F32Cx4 __m128 From bf74b557d2a960628ba380acb43da5396aa4cb96 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Sat, 30 May 2026 15:26:13 +0300 Subject: [PATCH 219/289] metal : restore im2col implementation for large kernels (llama/23901) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 24 +++-- ggml/src/ggml-metal/ggml-metal.metal | 106 +++++++++++----------- 3 files changed, 77 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ba006d9b31a..5d4b10d34b9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1732,6 +1732,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); + GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); @@ -1739,7 +1741,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta char base[256]; char name[256]; - snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + if (ne00*ne01 <= 1024) { + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + } else { + snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); + } snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 206af227a2c..e2ce56e9e28 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3635,16 +3635,26 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); - GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + if (KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); - const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + } else { + const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N); + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1); + } return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e772664ba91..4adf4614acb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4696,59 +4696,59 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; -// TODO: obsolete -- remove -//typedef void (im2col_ext_t)( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]); -// -//template <typename T> -//kernel void kernel_im2col_ext( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] -// const int64_t KHW = (int64_t)args.KHW; -// -// const int64_t d = tgpig[0] / args.CHW; -// const int64_t chw = tgpig[0] % args.CHW; -// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) -// const int64_t HW = tgpig[0] % KHW; -// -// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; -// if (tpitg_0 >= args.N) { -// return; -// } -// -// const int64_t tpitg_1 = HW / args.KW; -// const int64_t tpitg_2 = HW % args.KW; -// -// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; -// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; -// -// const int64_t offset_dst = -// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + -// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); -// -// device T * pdst = (device T *) (dst); -// -// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { -// pdst[offset_dst] = 0.0f; -// } else { -// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; -// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; -// } -//} -// -//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; -//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; +// TODO: optimize +typedef void (im2col_ext_t)( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template <typename T> +kernel void kernel_im2col_ext( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int64_t KHW = (int64_t)args.KHW; + + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; + + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= args.N) { + return; + } + + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % args.KW; + + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; + + const int64_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; template <typename TK> kernel void kernel_conv_2d( From 1c0d1f0f7c40a18670a460cc23b654ceb679ba9a Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Sat, 30 May 2026 10:17:47 -0700 Subject: [PATCH 220/289] opencl: support bf16 by converting to f16 (llama/23839) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 81 +++++++++++++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 42 +++++++++++++++ 2 files changed, 121 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 751ec6116c0..3f3643a4cef 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -585,6 +585,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; + cl_kernel kernel_convert_bf16_to_f16, kernel_convert_f16_to_bf16; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; @@ -1175,6 +1176,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_bf16_to_f16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_bf16_to_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_f16_to_bf16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_f16_to_bf16", &err), err)); GGML_LOG_CONT("."); } @@ -5019,6 +5022,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_MUL_MAT: if (op->src[0]->type == GGML_TYPE_F16) { return true; + } else if (op->src[0]->type == GGML_TYPE_BF16) { + return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || @@ -6828,6 +6833,40 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, } #endif // GGML_OPENCL_SOA_Q + // convert bf16 to f16 and store as f16 in device buffer + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_dst = (extra->offset + offset) / sizeof(ggml_fp16_t); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, + size, (void *) data, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_bf16_to_f16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + CL_CHECK(clReleaseEvent(evt)); + + return; + } + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; GGML_ASSERT(extra); @@ -7676,6 +7715,41 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, } #endif // GGML_OPENCL_SOA_Q + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_src = (extra->offset + tensor->view_offs + offset) / sizeof(ggml_fp16_t); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_f16_to_bf16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); + + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, 0, size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + + return; + } + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; CL_CHECK(clEnqueueReadBuffer( @@ -8165,6 +8239,7 @@ static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor kernel = backend_ctx->kernel_cpy_f32_f32; break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: // stored as f16 on device kernel = backend_ctx->kernel_cpy_f16_f16; break; default: @@ -11125,7 +11200,8 @@ static bool ggml_cl_can_use_adreno_xmem_gemm_f16_f32( if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { return false; } - if (src0->type != GGML_TYPE_F16 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_BF16) || + src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { return false; } if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { @@ -12843,7 +12919,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0->type; + // bf16 is stored as f16 on device + const enum ggml_type src0t = (src0->type == GGML_TYPE_BF16) ? GGML_TYPE_F16 : src0->type; const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c25eabdd72b..4f01887efb3 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -117,6 +117,48 @@ struct block_iq4_nl uint8_t qs[QK4_NL / 2]; }; +//------------------------------------------------------------------------------ +// bf16 to f16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_bf16_to_f16( + global const ushort * src, + global half * dst, + ulong off_dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + dst[i + off_dst] = (half) as_float((uint) src[i] << 16); +} + +//------------------------------------------------------------------------------ +// f16 to bf16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_f16_to_bf16( + global const half * src, + ulong off_src, + global ushort * dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + float f = (float) src[i + off_src]; + uint bits = as_uint(f); + if ((bits & 0x7fffffffu) > 0x7f800000u) { + // nan to quiet nan + dst[i] = (ushort)((bits >> 16) | 0x40u); + } else { + uint rounded = bits + 0x7fffu + ((bits >> 16) & 1u); + dst[i] = (ushort)(rounded >> 16); + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). From 687fbcb149c8e28bec2f563c1e9a081872b4e44c Mon Sep 17 00:00:00 2001 From: Neo Zhang <zhang.jianyu@outlook.com> Date: Mon, 1 Jun 2026 14:50:55 +0800 Subject: [PATCH 221/289] sycl : Optimize Q3_K mul_mat by reorder (llama/23725) --- ggml/src/ggml-sycl/convert.cpp | 25 ++++++- ggml/src/ggml-sycl/dequantize.hpp | 57 ++++++++++++++ ggml/src/ggml-sycl/dmmv.cpp | 120 +++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 52 +++++++++++++ ggml/src/ggml-sycl/mmvq.cpp | 30 +++++++- ggml/src/ggml-sycl/quants.hpp | 25 +++++++ ggml/src/ggml-sycl/vecdotq.hpp | 35 +++++++++ 7 files changed, 340 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 576f19d79ae..65593402e7d 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -107,6 +107,19 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template <typename dst_t> +static void dequantize_row_q3_K_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K_reorder(vx, y, item_ct1, nb); + }); +} + template <typename dst_t> static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -652,7 +665,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q4_K_sycl_reorder; @@ -730,7 +747,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 2324bfacd22..a723d2afbd6 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -390,6 +390,63 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri } +template<typename dst_t> +static void dequantize_block_q3_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { +#if QK_K == 256 + const int64_t i = item_ct1.get_group(2); + if (i >= n_blocks) { + return; + } + + const uint8_t * base = static_cast<const uint8_t *>(vx); + const size_t qs_offset = i * (QK_K / 4); + const size_t hmask_offset = n_blocks * (QK_K / 4) + i * (QK_K / 8); + const size_t scales_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + i * 12; + const size_t d_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + n_blocks * 12 + + i * sizeof(ggml_half); + + const uint8_t * qs = base + qs_offset; + const uint8_t * hmask = base + hmask_offset; + const uint8_t * scales = base + scales_offset; + const float d_all = static_cast<float>(*reinterpret_cast<const ggml_half *>(base + d_offset)); + + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r / 2; + const int64_t is0 = r % 2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4 * n; + const int64_t is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + uint8_t m = 1 << (4 * n + j); + + uint8_t us = is < 4 + ? (scales[is - 0] & 0xF) | (((scales[is + 8] >> 0) & 3) << 4) + : is < 8 + ? (scales[is - 0] & 0xF) | (((scales[is + 4] >> 2) & 3) << 4) + : is < 12 + ? (scales[is - 8] >> 4) | (((scales[is + 0] >> 4) & 3) << 4) + : (scales[is - 8] >> 4) | (((scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + + dst_t * y = yy + i * QK_K + 128 * n + 32 * j; + const uint8_t * q = qs + 32 * n; + const uint8_t * hm = hmask; + + for (int l = l0; l < l0 + 4; ++l) { + y[l] = dl * ((int8_t) ((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(item_ct1); + GGML_UNUSED(n_blocks); + GGML_ABORT("Q3_K reorder dequantize not supported for QK_K != 256"); +#endif +} + #if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 4ae431a962e..d80b0a38219 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -501,6 +501,103 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q3_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * (QK_K/4)] [hmask: nb * (QK_K/8)] [scales: nb * 12] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * hmask_base = qs_base + (size_t)nb * (QK_K / 4); + const uint8_t * scales_base = hmask_base + (size_t)nb * (QK_K / 8); + const sycl::half * d_base = (const sycl::half *)(scales_base + (size_t)nb * 12); + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = qs_base + bi * (QK_K / 4) + q_offset; + const uint8_t * h = hmask_base + bi * (QK_K / 8) + l0; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * 12); + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = d_base[bi]; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(ncols); + GGML_UNUSED(item_ct1); + GGML_ABORT("Q3_K reorder DMMV not supported for QK_K != 256"); +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:6: The total declared local variable size in device function dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register @@ -1440,6 +1537,22 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q3_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q3_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, float *dst, const int ncols, const int nrows, @@ -1581,7 +1694,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q3_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 729a88b4db8..e59f5c174d3 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3549,6 +3549,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: return true; + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -3572,6 +3573,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -3791,6 +3793,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d return true; } +static bool reorder_qw_q3_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q3_K) == 0); + GGML_ASSERT(offset % sizeof(block_q3_K) == 0); + + const int nblocks = size / sizeof(block_q3_K); + + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + auto * qs_ptr = data_device; + auto * hmask_ptr = qs_ptr + (QK_K / 4) * nblocks; + auto * scales_ptr = hmask_ptr + (QK_K / 8) * nblocks; + sycl::half * d_ptr = (sycl::half *) (scales_ptr + 12 * nblocks); + + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q3_K * x = (const block_q3_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 4; ++j) { + qs_ptr[ib * (QK_K / 4) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < QK_K / 8; ++j) { + hmask_ptr[ib * (QK_K / 8) + j] = x[ib].hmask[j]; + } + + for (int j = 0; j < 12; ++j) { + scales_ptr[ib * 12 + j] = x[ib].scales[j]; + } + + d_ptr[ib] = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; +} + static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q5_K) == 0); GGML_ASSERT(offset % sizeof(block_q5_K) == 0); @@ -3903,6 +3953,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q8_0: return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); + case GGML_TYPE_Q3_K: + return reorder_qw_q3_k(data_device, size, 0, stream); case GGML_TYPE_Q4_K: return reorder_qw_q4_k(data_device, size, 0, stream); case GGML_TYPE_Q5_K: diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 49998f13ba8..abd1e49a70e 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -770,6 +770,26 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -1153,7 +1173,15 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, + stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n"); + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 806028ef3a3..95287f17510 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -58,6 +58,31 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t<GGML_TYPE_Q3_K> { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI3_K; + static constexpr uint32_t qr = QR3_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + // Reordered layout: [qs (QK_K/4 per block)] [hmask (QK_K/8 per block)] [scales] [d] + static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) { + auto qs_offset = block_index * (QK_K / 4); + auto hmask_offset = n_blocks * (QK_K / 4) + block_index * (QK_K / 8); + return { qs_offset, hmask_offset }; + } + + static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / QK_K)); + auto total_qs_bytes = nblocks * (QK_K / 4) + nblocks * (QK_K / 8); + return { total_qs_bytes + block_index * 12, + total_qs_bytes + nblocks * 12 + block_index * sizeof(ggml_half) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + template <> struct block_q_t<GGML_TYPE_Q4_K> { struct traits { static constexpr uint32_t qk = QK_K; diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 16b2d65d271..4b58b09ab2c 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -394,6 +394,41 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> { } }; +template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K> { + static constexpr ggml_type gtype = GGML_TYPE_Q3_K; + + using q3_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q3_K>; + using q3_k_traits = typename q3_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset, + const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast<const uint8_t *>(vbq); + const uint8_t * qs = base + ibx_offset.first; + const uint8_t * hmask = base + ibx_offset.second; + const uint8_t * scales = base + d_offset.first; + const ggml_half d = *reinterpret_cast<const ggml_half *>(base + d_offset.second); + + const int bq8_offset = QR3_K * (iqs / (QI3_K / 2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2); + + const int vl = get_int_from_uint8(qs, iqs); + const int vh = ~get_int_from_uint8(hmask, iqs % (QI3_K / 2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int8_t * quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + u[i] = get_int_from_int8_aligned(quant_base_ptr, iqs % QI8_1); + d8[i] = (*(q8_1_ds + bq8_offset + i))[0]; + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, scales, scale_offset, static_cast<float>(d), d8); + } +}; + static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { From 20323e48c4612b508e4776ec6ef93cdefbc8325c Mon Sep 17 00:00:00 2001 From: Neo Zhang <zhang.jianyu@outlook.com> Date: Mon, 1 Jun 2026 14:53:04 +0800 Subject: [PATCH 222/289] Add more types in GET_ROWS OP (llama/23710) * add to support Q1_0, NVFP4, IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ1_S, IQ1_M, IQ3_S, IQ4_NL, IQ4_XS, I32, MXFP4, Q2_K, Q3_K, Q5_K, and Q6_K in GET_ROWS OP * correct the link --- ggml/src/ggml-sycl/dequantize.hpp | 472 ++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/getrows.cpp | 78 ++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 18 ++ 3 files changed, 565 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index a723d2afbd6..ca8cd96c08c 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -20,6 +20,10 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, const int iqs, dfloat2 &v); +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m); +#endif + static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -90,6 +94,474 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q4_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q4_K * x = (const block_q4_K *) vx; + const sycl::half2 dm = x[ib].dm; + const float dall = dm[0]; + const float dmin = dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int off = in & 31; + const int qsi = 32 * il + off; + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const uint8_t q = x[ib].qs[qsi]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + return sycl::fma((dfloat) qv, (dfloat) (dall * sc), (dfloat) (-dmin * m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q4_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q2_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q2_K * x = (const block_q2_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int g = r / 32; + const int l = r % 32; + const int is = 8 * n + l / 16; + + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t sc = x[ib].scales[is + 2 * g]; + const float d = dall * (sc & 0xF); + const float m = dmin * (sc >> 4); + + return sycl::fma((dfloat) ((q >> (2 * g)) & 3), (dfloat) d, (dfloat) (-m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q2_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q3_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q3_K * x = (const block_q3_K *) vx; + const float d_all = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int j = r / 32; + const int l = r % 32; + + const int is0 = l / 16; + const int is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + const uint8_t m = 1 << (4 * n + j); + + const int8_t us = is < 4 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 8] >> 0) & 3) << 4) : + is < 8 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 4] >> 2) & 3) << 4) : + is < 12 ? (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is + 0] >> 4) & 3) << 4) : + (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t h = x[ib].hmask[l]; + const int8_t qv = ((q >> shift) & 3) - ((h & m) ? 0 : 4); + + return (dfloat) (dl * qv); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q3_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q5_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q5_K * x = (const block_q5_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int ir = (in & 31) / 2; + const int iq = in & 1; + + const uint8_t q = x[ib].qs[32 * il + 2 * ir + iq]; + const uint8_t h = x[ib].qh[2 * ir + iq]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const float d = dall * sc; + const float mn = dmin * m; + const uint8_t hm = 1 << (2 * il + (in >= 32 ? 1 : 0)); + + return sycl::fma((dfloat) (qv + ((h & hm) ? 16 : 0)), (dfloat) d, (dfloat) (-mn)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q5_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q6_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q6_K * x = (const block_q6_K *) vx; + const float d = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ip = idx / 128; + const int in = idx % 128; + const int il = in & 31; + const int ig = in / 32; + const int is = 8 * ip + il / 16; + + const uint8_t ql0 = x[ib].ql[64 * ip + il]; + const uint8_t ql1 = x[ib].ql[64 * ip + il + 32]; + const uint8_t qh = x[ib].qh[32 * ip + il]; + const int8_t * sc = x[ib].scales + is; + + uint8_t qv; + int8_t scale; + if (ig == 0) { + qv = (ql0 & 0xF) | (((qh >> 0) & 3) << 4); + scale = sc[0]; + } else if (ig == 1) { + qv = (ql1 & 0xF) | (((qh >> 2) & 3) << 4); + scale = sc[2]; + } else if (ig == 2) { + qv = (ql0 >> 4) | (((qh >> 4) & 3) << 4); + scale = sc[4]; + } else { + qv = (ql1 >> 4) | (((qh >> 6) & 3) << 4); + scale = sc[6]; + } + + return (dfloat) (d * scale * ((int8_t) qv - 32)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q6_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_mxfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_mxfp4 * x = (const block_mxfp4 *) vx; + const float d = ggml_sycl_e8m0_to_fp32(x[ib].e); + const uint8_t q = x[ib].qs[iqs]; + + v.x() = d * kvalues_mxfp4[q & 0xF] * 0.5f; + v.y() = d * kvalues_mxfp4[q >> 4] * 0.5f; +} + +static __dpct_inline__ void dequantize_q1_0(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q1_0 * x = (const block_q1_0 *) vx; + const dfloat d = x[ib].d; + + const int bit_index_0 = iqs + 0; + const int bit_index_1 = iqs + 1; + + const int bit_0 = (x[ib].qs[bit_index_0 / 8] >> (bit_index_0 % 8)) & 1; + const int bit_1 = (x[ib].qs[bit_index_1 / 8] >> (bit_index_1 % 8)) & 1; + + v.x() = (2 * bit_0 - 1) * d; + v.y() = (2 * bit_1 - 1) * d; +} + +static __dpct_inline__ void dequantize_nvfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_nvfp4 & xb = ((const block_nvfp4 *) vx)[ib]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int sub = idx / QK_NVFP4_SUB; + const int j = idx % QK_NVFP4_SUB; + const int jh = j % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + jh]; + const uint8_t qv = (j < (QK_NVFP4_SUB / 2)) ? (q & 0x0F) : (q >> 4); + + return d * kvalues_mxfp4[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq2_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * aux8 = (const uint8_t *) q2; + const uint8_t * grid = (const uint8_t *) (iq2xxs_grid + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * grid = (const uint8_t *) (iq2xs_grid + (q2[il] & 511)); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_s * x = (const block_iq2_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 0x300); + const uint8_t * grid = (const uint8_t *) (iq2s_grid + grid_id); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = x[ib].qs[QK_K / 8 + 4 * ib8 + il]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * q3 = x[ib].qs + 8 * ib8; + const uint16_t * gas = (const uint16_t *) (x[ib].qs + QK_K / 4) + 2 * ib8; + const uint8_t * grid1 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 0]); + const uint8_t * grid2 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_s * x = (const block_iq3_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * qs = x[ib].qs + 8 * ib8; + const uint16_t grid1_id = qs[2 * il + 0] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 256); + const uint16_t grid2_id = qs[2 * il + 1] | ((x[ib].qh[ib8] << (7 - 2 * il)) & 256); + const uint8_t * grid1 = (const uint8_t *) (iq3s_grid + grid1_id); + const uint8_t * grid2 = (const uint8_t *) (iq3s_grid + grid2_id); + const float d = (float) x[ib].d * (1 + 2 * ((x[ib].scales[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf)); + const uint8_t signs = x[ib].signs[4 * ib8 + il]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_s * x = (const block_iq1_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const float delta = (x[ib].qh[ib8] & 0x8000) ? (-1.f - IQ1S_DELTA) : (-1.f + IQ1S_DELTA); + const float d = (float) x[ib].d * (2 * ((x[ib].qh[ib8] >> 12) & 7) + 1); + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((x[ib].qh[ib8] >> (3 * il)) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_m(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_m * x = (const block_iq1_m *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * sc = (const uint16_t *) x[ib].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + const int ib16 = 2 * ib8 + il / 2; + const float d = (float) scale.f16 * (2 * ((sc[ib16 / 4] >> (3 * (ib16 % 4))) & 0x7) + 1); + + const uint8_t qh = x[ib].qh[2 * ib8 + il / 2]; + const float delta = (qh & (0x08 << (4 * (il % 2)))) ? (-1.f - IQ1M_DELTA) : (-1.f + IQ1M_DELTA); + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((qh >> (4 * (il % 2))) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_M dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq4_nl(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_iq4_nl * x = (const block_iq4_nl *) vx; + const float d = (float) x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + if (idx < 16) { + return d * kvalues_iq4nl[x[ib].qs[idx] & 0xF]; + } + return d * kvalues_iq4nl[x[ib].qs[idx - 16] >> 4]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq4_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq4_xs * x = (const block_iq4_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int byte_idx = (r < 16) ? r : (r - 16); + const uint8_t q = x[ib].qs[16 * ib8 + byte_idx]; + const uint8_t qv = (r < 16) ? (q & 0x0F) : (q >> 4); + + const float d = (float) x[ib].d * ((((x[ib].scales_l[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf) | + (((x[ib].scales_h >> (2 * ib8)) & 3) << 4)) - 32); + return d * kvalues_iq4nl[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ4_XS dequantize not supported for QK_K != 256"); +#endif +} + static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_0 * x = (const block_q5_0 *) vx; diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index ca457454775..298f247f84e 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -129,11 +129,11 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr GGML_UNUSED(ctx); } -template <typename src0_t> +template <typename src0_t, typename dst_t> static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const src0_t *src0_dd, const int32_t *src1_dd, - float *dst_dd, queue_ptr stream) { + dst_t *dst_dd, queue_ptr stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -170,7 +170,7 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32 ); GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type)); GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type)); @@ -191,6 +191,66 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_I32: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const int32_t *)dst->src[0]->data, + src1_i32, (int32_t *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q1_0: + get_rows_sycl<QK1_0, 1, dequantize_q1_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_MXFP4: + get_rows_sycl<QK_MXFP4, 2, dequantize_mxfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_NVFP4: + get_rows_sycl<QK_NVFP4, 1, dequantize_nvfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_S: + get_rows_sycl<QK_K, 1, dequantize_iq2_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq3_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_S: + get_rows_sycl<QK_K, 1, dequantize_iq1_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_M: + get_rows_sycl<QK_K, 1, dequantize_iq1_m>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_S: + get_rows_sycl<QK_K, 1, dequantize_iq3_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_NL: + get_rows_sycl<QK4_NL, 1, dequantize_iq4_nl>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_XS: + get_rows_sycl<QK_K, 1, dequantize_iq4_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q2_K: + get_rows_sycl<QK_K, 1, dequantize_q2_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q3_K: + get_rows_sycl<QK_K, 1, dequantize_q3_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q4_0: get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -199,6 +259,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q4_K: + get_rows_sycl<QK_K, 1, dequantize_q4_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q5_0: get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -207,6 +271,14 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q5_K: + get_rows_sycl<QK_K, 1, dequantize_q5_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q6_K: + get_rows_sycl<QK_K, 1, dequantize_q6_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q8_0: get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e59f5c174d3..96138f57ebe 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -5301,13 +5301,31 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { + case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_0: return true; default: From ec0c6619500e86d2c2f290d1c95a3f022397fbcf Mon Sep 17 00:00:00 2001 From: Neo Zhang <zhang.jianyu@outlook.com> Date: Mon, 1 Jun 2026 14:53:53 +0800 Subject: [PATCH 223/289] Support Q4_1, Q5_0, Q5_1 in Flash-attention (llama/23812) * support Q4_1, Q5_0, Q5_1 * update ut case --- ggml/src/ggml-sycl/common.hpp | 1 + ggml/src/ggml-sycl/fattn-common.hpp | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 31e26ff48e4..d8bb3638dfd 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -45,6 +45,7 @@ namespace syclexp = sycl::ext::oneapi::experimental; #define GGML_COMMON_IMPL_SYCL #define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building. #define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building +#define GGML_SYCL_FA_ALL_QUANTS //define it to enable all quantization types in flash attention. undefine it to only support F16, Q4_0 and Q8_0 in flash attention. /* suppress warning spam */ #pragma clang diagnostic push diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp index 03f0c2623c8..c6cc13cfb00 100644 --- a/ggml/src/ggml-sycl/fattn-common.hpp +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -1031,7 +1031,7 @@ void launch_fattn( auto KV_max_ptr_ct1 = KV_max.ptr; cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_mask_to_KV_max<ncols1, warp_size>( mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33, @@ -1149,7 +1149,7 @@ void launch_fattn( auto K_ne_ct6 = K->ne[2]; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1, Q_ne_ct2, Q_ne_ct3, Q_ne_ct4, @@ -1169,7 +1169,7 @@ void launch_fattn( auto KQV_data_ct2 = (float *) KQV->data; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_combine_results<DV>( dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks, From aea93ada610cf565e0585dfe2822cd4a2206a488 Mon Sep 17 00:00:00 2001 From: Winston Ma <winstonma@ymail.com> Date: Mon, 1 Jun 2026 17:46:23 +0800 Subject: [PATCH 224/289] vulkan: Removed unused functions (llama/23175) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2a30fb95c61..74104149db8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7166,13 +7166,6 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx->s->buffer->buf.dispatch(wg0, wg1, wg2); } -static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) { - s.buffer->buf.end(); - - s.wait_semaphores = std::move(wait_semaphores); - s.signal_semaphores = std::move(signal_semaphores); -} - static void ggml_vk_ctx_end(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); if (ctx->s == nullptr) { @@ -14510,12 +14503,6 @@ static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_ty UNUSED(buft); } -static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { - return GGML_VK_NAME "_Host"; - - UNUSED(buffer); -} - static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); From 982533fc0c38dabc6f7fa9155b7e33e5f565e223 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Mon, 1 Jun 2026 09:46:48 +0000 Subject: [PATCH 225/289] vulkan: Block-load Q3_K/Q6_K block data and subtract on 32b ints (llama/23056) Q2_K/Q3_K/Q6_K do much better when using MMVQ on Intel BMG even though they're only 2-byte aligned, and Q3_K still wins on NVIDIA as well. mesa isn't all that great at coalescing back-to-back loads from alternating arrays, so we force it instead. Further, we can do subtraction directly on a full int32_t rather than an i8vec4 with bit twiddling because the high bit is always free to start. On Intel BMG on mesa, the switch to MMVQ provides an immediate ~57% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q3_K and ~78% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q6_K. The futher switch to block loads leads to a ~24% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q3_K and a ~48% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q6_K. Finally, Xe2 wins on MMVQ even for small k, so we take the NVIDIA override for K quants on Xe2 as well. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 19 ++- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 108 +++++++++++------- 2 files changed, 80 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 74104149db8..3cf191f2085 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8336,8 +8336,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } - // General performance issue with q3_k and q6_k due to 2-byte alignment - if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + // q6_k only has 2-byte alignment which makes it somewhat problematic, + // using MMVQ is only a win on Intel. + bool mmvq_q6 = device->vendor_id == VK_VENDOR_ID_INTEL; + if (src0_type == GGML_TYPE_Q6_K && !mmvq_q6) { return false; } @@ -8349,7 +8351,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: - if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { return true; } @@ -8376,9 +8378,16 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (device->architecture == vk_device_architecture::INTEL_XE2) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return true; + } + } + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { - // Intel Windows proprietary driver MMVQ performance is worse than fp16, see - // https://github.com/ggml-org/llama.cpp/issues/17628 + // Intel Windows proprietary driver MMVQ performance for !Q2/Q3/Q6 is worse than fp16, + // see https://github.com/ggml-org/llama.cpp/issues/17628 and + // https://github.com/ggml-org/llama.cpp/pull/23056 return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index bc580aeeb83..73cf9c79955 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -212,28 +212,40 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qs_shift = ((iqs_k % 32) / 8) * 2; const uint hm_shift = iqs_k / 8; + const uvec4 qs = uvec4( uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 7]) << 16)); + + const uvec4 hmask = uvec4( uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 ]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 7]) << 16)); + // bitwise OR to add 4 if hmask is set, subtract later - const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); + const uint vals0 = (( qs.x >> qs_shift) & 0x03030303) | + (((hmask.x >> hm_shift) & 0x01010101) << 2); + const uint vals1 = (( qs.y >> qs_shift) & 0x03030303) | + (((hmask.y >> hm_shift) & 0x01010101) << 2); + const uint vals2 = (( qs.z >> qs_shift) & 0x03030303) | + (((hmask.z >> hm_shift) & 0x01010101) << 2); + const uint vals3 = (( qs.w >> qs_shift) & 0x03030303) | + (((hmask.w >> hm_shift) & 0x01010101) << 2); + + // Subtract 4 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x04040404) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { @@ -343,27 +355,39 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qh_idx = (iqs_k / 32) * 8 + iqs; const uint qh_shift = ((iqs_k % 32) / 8) * 2; - const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); + const uvec4 ql = uvec4( uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 7]) << 16)); + + const uvec4 qh = uvec4( uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 7]) << 16)); + + const uint vals0 = (( ql.x >> ql_shift) & 0x0F0F0F0F) | + (((qh.x >> qh_shift) & 0x03030303) << 4); + const uint vals1 = (( ql.y >> ql_shift) & 0x0F0F0F0F) | + (((qh.y >> qh_shift) & 0x03030303) << 4); + const uint vals2 = (( ql.z >> ql_shift) & 0x0F0F0F0F) | + (((qh.z >> qh_shift) & 0x03030303) << 4); + const uint vals3 = (( ql.w >> ql_shift) & 0x0F0F0F0F) | + (((qh.w >> qh_shift) & 0x03030303) << 4); + + // Subtract 32 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x20202020) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { From e815b264eba131de2d06be039612fad9eb4330f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= <johannesg@5d6.de> Date: Mon, 1 Jun 2026 12:30:10 +0200 Subject: [PATCH 226/289] TP: quantized KV cache support (llama/23792) * TP: quantized KV cache support * fix partial view * remove overly strict assert --- ggml/include/ggml-backend.h | 10 +- ggml/src/ggml-backend-meta.cpp | 278 +++++++++++++++++---------------- 2 files changed, 149 insertions(+), 139 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index b6f73739809..2924fdbe988 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -381,11 +381,15 @@ extern "C" { // - most tensors have n_segments == 1 and a contiguous slice of the tensor data // - some tensors have an inhomogenenous data layout along the split axis, // those tensors are divided into segments which are each individually split across devices - // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis, - // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1], + // - ne has one entry per segment and device and that segment repeats nr times, + // in total when accounting for repetitions the segments add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0_r0, seg0_dev1_r0, seg0_dev0_r1, seg0_dev1_r1, seg1_dev0_r0, seg1_dev1_r0], // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments - // that each need to be split individually across devices so that each device gets a slice of Q, K, and V + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V, + // the Q matrix can be larger than the K and V matrices so this can either be expressed as 3 segments or as 2 segments + // where the segment for K/V repeats twice int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t nr[16]; uint32_t n_segments; }; diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 48b2027fac3..8c44c3e44ae 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -487,6 +487,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { + // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way. + // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there. + // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results. const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; @@ -497,11 +500,11 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( for (size_t j = 0; j < n_bufs; j++) { int64_t sum_a = 0; for (size_t s = 0; s < a.n_segments; s++) { - sum_a += a.ne[s*n_bufs + j]; + sum_a += a.ne[s*n_bufs + j] * a.nr[s]; } int64_t sum_b = 0; for (size_t s = 0; s < b.n_segments; s++) { - sum_b += b.ne[s*n_bufs + j]; + sum_b += b.ne[s*n_bufs + j] * b.nr[s]; } if (sum_a != sum_b) { return false; @@ -511,7 +514,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( }; auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { - ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}; for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { continue; @@ -519,15 +522,15 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { ret = src_ss[i]; } else if (!split_states_equal(src_ss[i], ret)) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; break; } } if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); return ret; @@ -571,42 +574,24 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { ggml_backend_meta_split_state ret = src_ss[0]; ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.nr[0] = 1; ret.n_segments = 1; return ret; } if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - ggml_backend_meta_split_state ret = src_ss[1]; - ret.n_segments = 1; - return ret; + return src_ss[1]; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); - return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1}; + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1}; } GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; - }; - - auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { - if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { - int64_t ne_split_src = tensor->src[0]->ne[0]; - for (int dim = 1; dim <= src_ss[0].axis; dim++) { - ne_split_src *= tensor->src[0]->ne[dim]; - } - int64_t ne_split_dst = 1; - for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { - ne_split_dst *= tensor->ne[dim]; - if (ne_split_dst == ne_split_src) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } - } - } - return handle_generic(src_ss, /*scalar_only =*/ false); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { @@ -615,33 +600,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0])); - if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) { - return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1}; } - std::vector<int64_t> base_ne_in; - base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis); - { - base_ne_in.push_back(1); - int dim = 0; - for (; dim <= src_ss[0].axis; dim++) { - base_ne_in[0] *= tensor->src[0]->ne[dim]; - } - for (; dim <= GGML_MAX_DIMS; dim++) { - base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]); - } + int64_t base_ne_in = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + base_ne_in *= tensor->src[0]->ne[dim]; } + base_ne_in /= src_ss[0].nr[0]; int64_t base_ne_out = 1; for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; - for (const int64_t & bni : base_ne_in) { - if (bni == base_ne_out_next) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } + if (base_ne_out_next % base_ne_in == 0) { + return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1}; } - if (base_ne_out_next > base_ne_in[0]) { - GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); - return {ggml_backend_meta_split_axis(dim + 1), {0}, 1}; + if (base_ne_out_next > base_ne_in) { + GGML_ASSERT(src_ss[0].n_segments == 1); + GGML_ASSERT(src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } base_ne_out = base_ne_out_next; } @@ -653,11 +630,18 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; + auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + return handle_reshape(src_ss); + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { return handle_reshape(src_ss); @@ -681,7 +665,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } } GGML_ABORT("fatal error"); @@ -690,7 +674,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( return src_ss[0]; } GGML_ABORT("view of permuted tensor not implemented"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { @@ -699,7 +683,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_MIRRORED: case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { @@ -707,7 +692,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -716,7 +701,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( switch (src_ss[0].axis) { case GGML_BACKEND_SPLIT_AXIS_0: case GGML_BACKEND_SPLIT_AXIS_1: { - return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: @@ -726,7 +712,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -764,16 +750,16 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; }; auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == src_ss[1].axis) { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; } } return handle_generic(src_ss, /*scalar_only =*/ false); @@ -781,8 +767,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { return src_ss[0]; } GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -793,12 +779,12 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; }; auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { if (ggml_nelements(tensor) == 0) { - return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); @@ -807,19 +793,21 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; int64_t ne_sum = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - GGML_ASSERT(ret.ne[sj] % granularity == 0); - ne_sum += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0); + ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); } return ret; } - std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}); + std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}); for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { - src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; continue; } src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); @@ -829,7 +817,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_split_state split_state; switch (tensor->op) { case GGML_OP_NONE: { - split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } break; case GGML_OP_DUP: { split_state = handle_generic(src_ss, /*scalar_only =*/ true); @@ -1016,7 +1004,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } break; default: { GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); - split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } break; } if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { @@ -1034,23 +1022,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( split_state.ne[s*n_bufs + j] = 0; } for (size_t s = 0; s < src_ss[i].n_segments; s++) { - split_state.ne[j] += src_ss[i].ne[s*n_bufs + j]; + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } split_state.ne[j] *= tensor->ne[split_state.axis]; if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { - GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0); - split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis]; + const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0]; + GGML_ASSERT(split_state.ne[j] % div == 0); + split_state.ne[j] /= div; } } } else { + GGML_ASSERT(split_state.n_segments == 1); for (size_t j = 0; j < n_bufs; j++) { + // Assert that ratio is consistent: int64_t sum = 0; for (size_t s = 0; s < src_ss[i].n_segments; s++) { - sum += src_ss[i].ne[s*n_bufs + j]; + sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } - // Assert that ratio is consistent: - GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] - == sum * tensor->ne[split_state.axis]); + GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); } } first_src_split_by_axis = false; @@ -1080,13 +1070,14 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( srcs_info += ", "; } const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + GGML_ASSERT(split_state.n_segments == 1); const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); std::string ne_info; for (size_t j = 0; j < n_bufs; j++) { if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(split_state.ne[j]); + ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]); } srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; } @@ -1095,7 +1086,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]); + const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first; + ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]); } GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); @@ -1107,8 +1099,10 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( #ifndef NDEBUG if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { int64_t ne_ret = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - ne_ret += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } assert(ne_ret == tensor->ne[int(ret.axis)]); } @@ -1155,7 +1149,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); ne[split_dim] = 0; for (size_t s = 0; s < split_state.n_segments; s++) { - ne[split_dim] += split_state.ne[s*n_simple_bufs + j]; + ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s]; } for (int i = 0; i < GGML_MAX_DIMS; i++) { if (tensor->nb[i] > tensor->nb[split_dim]) { @@ -1229,7 +1223,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m for (size_t j = 0; j < n_simple_bufs; j++) { int64_t ne_sum = 0; for (size_t s = 0; s < split_state_src.n_segments; s++) { - ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s]; } if (ne_sum == 0) { simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; @@ -1255,8 +1249,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1267,24 +1262,26 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1292,22 +1289,24 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1365,8 +1364,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1377,24 +1377,26 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1402,22 +1404,24 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1675,6 +1679,7 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: @@ -1719,6 +1724,7 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: From c471bcce1b2d7cdaf372d621b64b1275cb7a01d8 Mon Sep 17 00:00:00 2001 From: Winston Ma <winstonma@ymail.com> Date: Mon, 1 Jun 2026 20:03:32 +0800 Subject: [PATCH 227/289] vulkan: reduce host memory lock contention (llama/23376) * vulkan: reduces lock contention * replace unique_lock with lock_guard --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3cf191f2085..c3d4c7a7129 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -62,6 +62,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { #include <map> #include <set> #include <unordered_map> +#include <shared_mutex> #include <mutex> #include <future> #include <thread> @@ -618,6 +619,7 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie struct vk_device_struct { std::recursive_mutex mutex; + mutable std::shared_mutex pinned_memory_mutex; vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; @@ -7010,7 +7012,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { return nullptr; } - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); return buf->ptr; @@ -7021,7 +7023,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { return; } VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); vk_buffer buf; size_t index; @@ -7045,7 +7047,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { } static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::shared_lock<std::shared_mutex> guard(device->pinned_memory_mutex); buf = nullptr; buf_offset = 0; for (size_t i = 0; i < device->pinned_memory.size(); i++) { From 71d80aa49eb93868a8ed7e9f8abeae9e061adcfe Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Mon, 1 Jun 2026 07:04:01 -0500 Subject: [PATCH 228/289] vulkan: don't hold the device mutex while compiling pipelines (llama/23641) * vulkan: don't hold the device mutex while compiling pipelines We need to hold a lock while we traverse all pipelines and lazily initialize them, but we don't need to hold it while the pipeline is being compiled. And it doesn't need to be the same lock as the device mutex. We call load_shaders each time a pipeline is needed, so we only need to compile that one pipeline (and, for example, don't want to end up compiling a pipeline that another thread should be compiling). * remove 'needed' --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 144 ++++++++++++++++++--------- 1 file changed, 99 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c3d4c7a7129..e7d04634b8a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -65,6 +65,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { #include <shared_mutex> #include <mutex> #include <future> +#include <condition_variable> #include <thread> #if defined(_MSC_VER) @@ -159,8 +160,9 @@ struct vk_pipeline_struct { uint32_t align; // true if fields have been set by ggml_vk_create_pipeline bool initialized {}; - // set to true to request the pipeline is compiled - std::atomic<bool> needed {}; + // true while a compile is in flight, used to dedupe concurrent claims. + // Protected by device->compile_mutex. + bool compile_pending {}; // set to true when the shader has been compiled std::atomic<bool> compiled {}; // number of registers used, extracted from pipeline executable properties @@ -621,6 +623,13 @@ struct vk_device_struct { std::recursive_mutex mutex; mutable std::shared_mutex pinned_memory_mutex; + // Guards compile_pending, all_pipelines, and the dynamic pipeline maps + // (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile + // runs with no lock held, so different pipelines can compile in parallel. + // Lock order is device->mutex -> compile_mutex, never the reverse. + std::mutex compile_mutex; + std::condition_variable compile_cv; + vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; std::string name; @@ -1729,7 +1738,7 @@ struct ggml_vk_garbage_collector { }; static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx); -static void ggml_vk_load_shaders(vk_device& device); +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr); static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx); static bool vk_memory_logger_enabled = false; @@ -2196,11 +2205,6 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { ctx->device->device.resetFences({ ctx->fence }); } -// variables to track number of compiles in progress -static uint32_t compile_count = 0; -static std::mutex compile_count_mutex; -static std::condition_variable compile_count_cond; - static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367; static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447; static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4; @@ -2495,7 +2499,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::cerr << "ggml_vulkan: " << e.what() << std::endl; throw e; } - pipeline->compiled = true; if (vk_instance.debug_utils_support) { vk::DebugUtilsObjectNameInfoEXT duoni; @@ -2544,14 +2547,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } } - device->all_pipelines.push_back(pipeline); - { - std::lock_guard<std::mutex> guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; + std::lock_guard<std::mutex> guard(device->compile_mutex); + device->all_pipelines.push_back(pipeline); + pipeline->compiled = true; + pipeline->compile_pending = false; } - compile_count_cond.notify_all(); + device->compile_cv.notify_all(); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -2567,8 +2569,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { - pipeline->needed = true; - ggml_vk_load_shaders(ctx->device); + ggml_vk_load_shaders(ctx->device, pipeline); } ggml_pipeline_allocate_descriptor_sets(ctx); } @@ -3567,10 +3568,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type #endif } -static void ggml_vk_load_shaders(vk_device& device) { +// load_shaders walks the pipeline list under compile_mutex and either claims +// the requested pipeline for compilation or, if another thread is already +// compiling it, drops the lock and waits on compile_cv. Compiles themselves +// run unlocked. +struct CompileTask { + vk_pipeline pipeline; + size_t spv_size; + const void * spv_data; + std::string entrypoint; + uint32_t parameter_count; + std::array<uint32_t, 3> wg_denoms; + std::vector<uint32_t> specialization_constants; + bool disable_robustness; + bool require_full_subgroups; + uint32_t required_subgroup_size; +}; + +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); // some shaders have a minimum subgroup size const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); @@ -3600,6 +3617,15 @@ static void ggml_vk_load_shaders(vk_device& device) { l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; uint32_t l_align, m_align, s_align; + + vk_pipeline wait_pipeline; + CompileTask claimed_task {}; + bool has_claimed_task = false; + + // The rest of the walk reads and writes shared device state, so hold the + // lock until we're done deciding what to compile. + std::unique_lock<std::mutex> compile_lock(device->compile_mutex); + if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id l_warptile = { 256, 128, 256, 64, 1 }; @@ -3785,7 +3811,6 @@ static void ggml_vk_load_shaders(vk_device& device) { device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>(); } - std::vector<std::future<void>> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -3819,23 +3844,33 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif } - if (!pipeline->needed || pipeline->compiled) { + // We only care about the pipeline this call asked for; the rest + // (including the 64-bit indexing variant) are handled by their + // own request_descriptor_sets / load_shaders calls. + if (pipeline.get() != requested.get()) { continue; } - // TODO: We're no longer benefitting from the async compiles (shaders are - // compiled individually, as needed) and this complexity can be removed. - { - // wait until fewer than N compiles are in progress - uint32_t N = std::max(1u, std::thread::hardware_concurrency()); - std::unique_lock<std::mutex> guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; + + if (pipeline->compiled) { + continue; } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, - parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + wait_pipeline = pipeline; + + if (!pipeline->compile_pending) { + pipeline->compile_pending = true; + claimed_task.pipeline = pipeline; + claimed_task.spv_size = spv_size; + claimed_task.spv_data = spv_data; + claimed_task.entrypoint = entrypoint; + claimed_task.parameter_count = parameter_count; + claimed_task.wg_denoms = wg_denoms; + claimed_task.specialization_constants = specialization_constants; + claimed_task.disable_robustness = disable_robustness; + claimed_task.require_full_subgroups = require_full_subgroups; + claimed_task.required_subgroup_size = required_subgroup_size; + has_claimed_task = true; + } } }; @@ -5332,8 +5367,25 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - for (auto &c : compiles) { - c.wait(); + // Drop compile_mutex so other threads can walk while we compile. + compile_lock.unlock(); + + // Compile what we claimed; create_pipeline_func reacquires compile_mutex + // at the end to flip compile_pending/compiled and notify waiters. + if (has_claimed_task) { + auto & task = claimed_task; + ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data, + task.entrypoint, task.parameter_count, task.wg_denoms, + task.specialization_constants, task.disable_robustness, + task.require_full_subgroups, task.required_subgroup_size); + } + + // Another thread may be compiling the pipeline we need; block on it here. + if (wait_pipeline) { + std::unique_lock<std::mutex> wait_lock(device->compile_mutex); + device->compile_cv.wait(wait_lock, [&] { + return wait_pipeline->compiled.load(); + }); } } @@ -9722,7 +9774,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { @@ -9786,13 +9838,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline_fa_mask_opt = nullptr; if (use_mask_opt) { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_fa_mask_opt; - auto it = pipelines.find({Br, Bc}); - if (it != pipelines.end()) { - pipeline_fa_mask_opt = it->second; - } else { - pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>(); + { + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); + auto &pipelines = ctx->device->pipeline_fa_mask_opt; + auto it = pipelines.find({Br, Bc}); + if (it != pipelines.end()) { + pipeline_fa_mask_opt = it->second; + } else { + pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>(); + } } assert(pipeline_fa_mask_opt); ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1); @@ -10326,7 +10380,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); if (it != ctx->device->pipeline_solve_tri_f32.end()) { pipeline = it->second; @@ -10485,7 +10539,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = pipelines->find(conv2d_pipeline_state); if (it != pipelines->end()) { pipeline = it->second; From 050b8567a0fff75392c249d9283f8ee2dfa89292 Mon Sep 17 00:00:00 2001 From: Shrivas Shankar <86219405+shrivasshankar@users.noreply.github.com> Date: Mon, 1 Jun 2026 07:40:28 -0500 Subject: [PATCH 229/289] metal: template GLU kernels to support f16/f32 (llama/23882) Drops the hardcoded f32 GLU kernels in favor of a single template. We now load/store in the native tensor type (half or float) to save memory bandwidth, but keep the actual ALU compute in float to avoid exploding math in geglu/swiglu. Also opened up the dispatch gate to allow f16 inputs. --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 96 +++++++++++++++++-------- 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 885344ec670..196af102643 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1107,7 +1107,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_1(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4adf4614acb..2bd310d9450 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1421,7 +1421,8 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>; -kernel void kernel_reglu_f32( +template<typename T> +kernel void kernel_reglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1429,19 +1430,25 @@ kernel void kernel_reglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; const float x1 = src1_row[i0]; - dst_row[i0] = x0*x1*(x0 > 0.0f); + dst_row[i0] = (T)(x0*x1*(x0 > 0.0f)); } } -kernel void kernel_geglu_f32( +typedef decltype(kernel_reglu<float>) kernel_reglu_t; + +template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>; +template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>; + +template<typename T> +kernel void kernel_geglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1449,9 +1456,9 @@ kernel void kernel_geglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1459,11 +1466,17 @@ kernel void kernel_geglu_f32( const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); - dst_row[i0] = gelu*x1; + dst_row[i0] = (T)(gelu*x1); } } -kernel void kernel_swiglu_f32( +typedef decltype(kernel_geglu<float>) kernel_geglu_t; + +template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>; +template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>; + +template<typename T> +kernel void kernel_swiglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1471,9 +1484,9 @@ kernel void kernel_swiglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1481,11 +1494,17 @@ kernel void kernel_swiglu_f32( const float silu = x0 / (1.0f + exp(-x0)); - dst_row[i0] = silu*x1; + dst_row[i0] = (T)(silu*x1); } } -kernel void kernel_swiglu_oai_f32( +typedef decltype(kernel_swiglu<float>) kernel_swiglu_t; + +template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>; +template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>; + +template<typename T> +kernel void kernel_swiglu_oai( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1493,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { float x0 = src0_row[i0]; @@ -1507,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32( float out_glu = x0 / (1.0f + exp(-x0 * args.alpha)); out_glu = out_glu * (1.0f + x1); - dst_row[i0] = out_glu; + dst_row[i0] = (T)out_glu; } } -kernel void kernel_geglu_erf_f32( +typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t; + +template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>; +template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>; + +template<typename T> +kernel void kernel_geglu_erf( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1519,9 +1544,9 @@ kernel void kernel_geglu_erf_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1529,11 +1554,17 @@ kernel void kernel_geglu_erf_f32( const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV)); - dst_row[i0] = gelu_erf*x1; + dst_row[i0] = (T)(gelu_erf*x1); } } -kernel void kernel_geglu_quick_f32( +typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t; + +template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>; +template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>; + +template<typename T> +kernel void kernel_geglu_quick( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1541,9 +1572,9 @@ kernel void kernel_geglu_quick_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1551,10 +1582,15 @@ kernel void kernel_geglu_quick_f32( const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); - dst_row[i0] = gelu_quick*x1; + dst_row[i0] = (T)(gelu_quick*x1); } } +typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t; + +template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>; +template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>; + kernel void kernel_op_sum_f32( constant ggml_metal_kargs_sum & args, device const float * src0, From e728bae15950e1786b4c9574fa56a889e772f516 Mon Sep 17 00:00:00 2001 From: shaofeiqi <shaoqi@qti.qualcomm.com> Date: Mon, 1 Jun 2026 10:06:50 -0700 Subject: [PATCH 230/289] opencl: add basic support for q5_0 and q5_1 (llama/23548) * opencl: add general q5_0 support * opencl: add general q5_1 support * opencl: support non-uniform workgrp size --------- Co-authored-by: Li He <lih@qti.qualcomm.com> --- ggml/src/ggml-opencl/CMakeLists.txt | 6 + ggml/src/ggml-opencl/ggml-opencl.cpp | 422 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 100 +++++ .../kernels/mul_mm_q5_0_f32_l4_lm.cl | 173 +++++++ .../kernels/mul_mm_q5_1_f32_l4_lm.cl | 175 ++++++++ .../ggml-opencl/kernels/mul_mv_q5_0_f32.cl | 241 ++++++++++ .../kernels/mul_mv_q5_0_f32_flat.cl | 243 ++++++++++ .../ggml-opencl/kernels/mul_mv_q5_1_f32.cl | 243 ++++++++++ .../kernels/mul_mv_q5_1_f32_flat.cl | 247 ++++++++++ 9 files changed, 1845 insertions(+), 5 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 446fb727996..cd15d573238 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -87,6 +87,10 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 mul_mv_q4_k_f32_flat + mul_mv_q5_0_f32 + mul_mv_q5_0_f32_flat + mul_mv_q5_1_f32 + mul_mv_q5_1_f32_flat mul_mv_q5_k_f32 mul_mv_q5_k_f32_flat mul_mv_q6_k_f32 @@ -126,6 +130,8 @@ set(GGML_OPENCL_KERNELS mul_mm_f16_f32_l4_lm mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm + mul_mm_q5_0_f32_l4_lm + mul_mm_q5_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm mul_mm_iq4_nl_f32_l4_lm mul_mm_q4_k_f32_l4_lm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3f3643a4cef..7cafbe0cdc3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -576,7 +576,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; + cl_kernel kernel_convert_block_q5_0, kernel_restore_block_q5_0; cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; + cl_kernel kernel_convert_block_q5_1, kernel_restore_block_q5_1; cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; @@ -604,6 +606,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; + cl_kernel kernel_mul_mv_q5_0_f32; + cl_kernel kernel_mul_mv_q5_0_f32_flat; + cl_kernel kernel_mul_mv_q5_1_f32; + cl_kernel kernel_mul_mv_q5_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q4_K_f32_flat; cl_kernel kernel_mul_mv_q5_K_f32; @@ -662,6 +668,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_l4_lm; cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; @@ -1141,8 +1149,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); @@ -1485,6 +1497,74 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // mul_mv_q5_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_0_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_0_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_0_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_0_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_1_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_1_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_1_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_1_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q5_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1835,6 +1915,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // mul_mm_q5_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_0_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q5_1_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_1_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_1_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -5027,6 +5139,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || op->src[0]->type == GGML_TYPE_Q5_1 || op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_IQ4_NL || op->src[0]->type == GGML_TYPE_Q4_K || @@ -5977,7 +6090,24 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - return; + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; } if (tensor->type == GGML_TYPE_Q5_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -6078,6 +6208,24 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -7135,8 +7283,29 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // TODO: normal q5_0 - (void) extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); return; } if (tensor->type == GGML_TYPE_Q5_1) { @@ -7177,8 +7346,29 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // TODO: normal q5_1 - (void) extra; + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -12936,6 +13126,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; @@ -13271,6 +13463,93 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q5_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q5_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q8_0: { if (ne11 < 32) { break; @@ -13807,6 +14086,137 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q5_0: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_0_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_0_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_Q5_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; @@ -14247,6 +14657,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_Q2_K) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 4f01887efb3..d07f0a1a025 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -537,6 +537,53 @@ kernel void kernel_restore_block_q4_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_0 +// Convert the block_q5_0 format to 3 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_0( + global struct block_q5_0 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_0/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_0( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global struct block_q5_0 * dst +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_0/2; ++i) { + b->qs[i] = qs[i]; + } +} + kernel void kernel_convert_block_q5_0_trans4_ns( __global struct block_q5_0 * src0, __global uint * dst_qs, @@ -636,6 +683,59 @@ kernel void kernel_restore_block_q5_0_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_1 +// Convert the block_q5_1 format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_1( + global struct block_q5_1 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + global half * dst_m, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_1( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_1/2; ++i) { + b->qs[i] = qs[i]; + } +} + kernel void kernel_convert_block_q5_1_trans4_ns( __global struct block_q5_1 * src0, __global uint * dst_qs, diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl new file mode 100644 index 00000000000..1e980a478a8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl @@ -0,0 +1,173 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_0_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs_ptr = src0_qs + ib*4 + iqs; + uchar4 q = *qs_ptr; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = (convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) - 16.0f) * d; + float4 v2 = (convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) - 16.0f) * d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl new file mode 100644 index 00000000000..ba06be54697 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl @@ -0,0 +1,175 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_1_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs = src0_qs + ib*4 + iqs; + uchar4 q = *qs; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) * d + m; + float4 v2 = convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) * d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl new file mode 100644 index 00000000000..6d8c9e8f037 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl @@ -0,0 +1,241 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +struct block_q5_0 { + half d; + uchar qh[4]; + uchar qs[QK5_0 / 2]; +}; + +inline float block_q5_0_dot_y( + global const struct block_q5_0 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 6 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 2)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_0 * x = (global struct block_q5_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl new file mode 100644 index 00000000000..34ec133d398 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl @@ -0,0 +1,243 @@ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +inline float block_q5_0_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_0/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 0*nb*(QK5_0/2), qh + ib + 0*nb, d + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 1*nb*(QK5_0/2), qh + ib + 1*nb, d + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 2*nb*(QK5_0/2), qh + ib + 2*nb, d + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 3*nb*(QK5_0/2), qh + ib + 3*nb, d + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl new file mode 100644 index 00000000000..1480f675038 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl @@ -0,0 +1,243 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +struct block_q5_1 { + half d; + half m; + uchar qh[4]; + uchar qs[QK5_1 / 2]; +}; + +inline float block_q5_1_dot_y( + global const struct block_q5_1 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 8 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 4)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_1 * x = (global struct block_q5_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl new file mode 100644 index 00000000000..57c2f140958 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl @@ -0,0 +1,247 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +inline float block_q5_1_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_1/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global half * ms = (global half *) src0_m + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 0*nb*(QK5_1/2), qh + ib + 0*nb, d + ib + 0*nb, ms + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 1*nb*(QK5_1/2), qh + ib + 1*nb, d + ib + 1*nb, ms + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 2*nb*(QK5_1/2), qh + ib + 2*nb, d + ib + 2*nb, ms + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 3*nb*(QK5_1/2), qh + ib + 3*nb, d + ib + 3*nb, ms + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} From db2a39507ccb33f5e9b2aedeefa9d149990d1f0a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura <yoshimura.masashi.frbs@gmail.com> Date: Tue, 2 Jun 2026 08:59:06 +0900 Subject: [PATCH 231/289] revert to using global_invocation_id for cpy shader (llama/23955) --- ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index e268adfb16b..67f1dc0928f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -50,13 +50,13 @@ var<uniform> params: Params; @compute @workgroup_size(WG_SIZE) fn main( - @builtin(global_invocation_index) gindex: u32, + @builtin(global_invocation_id) gid: vec3<u32>, ) { - if (gindex >= params.ne) { + if (gid.x >= params.ne) { return; } - var i = gindex; + var i = gid.x; let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); let i2 = i / (params.src_ne1 * params.src_ne0); @@ -64,7 +64,7 @@ fn main( let i1 = i / params.src_ne0; let i0 = i % params.src_ne0; - var j = gindex; + var j = gid.x; let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); let j2 = j / (params.dst_ne1 * params.dst_ne0); @@ -80,4 +80,3 @@ fn main( dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); } - From 9a0265d13b890fffa18315f4da5de51147dc8ccd Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Mon, 1 Jun 2026 19:15:09 -0700 Subject: [PATCH 232/289] opencl: fix compiler warnings for non-adreno path (llama/23922) * opencl: fix compiler warnings for non-adreno path * opencl: fix const cast warning --- ggml/src/ggml-opencl/ggml-opencl.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 7cafbe0cdc3..b67ea46bce8 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -380,7 +380,7 @@ struct ggml_backend_opencl_device_context { ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; std::regex *opfilter = nullptr; // regex of ops to not claim - std::string opfilter_str; // regex string for opfilter + std::string opfilter_str = ""; // regex string for opfilter size_t global_mem_size = 0; }; @@ -6822,9 +6822,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Adreno MoE Q6_K kernel needs special transposed layout if (use_adreno_moe_kernels(backend_ctx, tensor)) { @@ -6858,6 +6855,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; @@ -6994,7 +6994,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - size, (void *) data, &err); + size, const_cast<void *>(data), &err); CL_CHECK(err); cl_kernel kernel = backend_ctx->kernel_convert_bf16_to_f16; @@ -7782,9 +7782,6 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, tensor)) { cl_int err; @@ -7794,6 +7791,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; @@ -14888,6 +14888,8 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int ne1 = dst->ne[1]; const int ne2 = dst->ne[2]; + GGML_UNUSED(ne2); + const int r2 = ne12/ne02; const int r3 = ne13/ne03; const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows @@ -14902,6 +14904,8 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int n_tile_size = 32; const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + GGML_UNUSED(max_post_router_tile); + cl_kernel kernel; // subgroup mat vec From 79223704a1ec33e4147ab6f0d10934667cea7200 Mon Sep 17 00:00:00 2001 From: Anav Prasad <anavp@nvidia.com> Date: Mon, 1 Jun 2026 19:38:37 -0700 Subject: [PATCH 233/289] clean up unused variables warnings (llama/23975) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 6 +++--- ggml/src/ggml-cuda/gated_delta_net.cu | 10 +++++----- ggml/src/ggml-cuda/mmf.cuh | 6 +++--- ggml/src/ggml-cuda/mmvf.cu | 13 ++++++------- ggml/src/ggml-cuda/mmvq.cu | 13 ++++--------- ggml/src/ggml-cuda/topk-moe.cu | 2 +- 6 files changed, 22 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3c8b6eaaf24..ac5abb13367 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -568,7 +568,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); - constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; @@ -604,9 +603,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; - const int k0_diff = k0_stop - k0_start; if constexpr (nstages <= 1) { + const int k0_diff = k0_stop - k0_start; constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check> (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); @@ -640,6 +639,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } else { + constexpr int stride_tile_Q = DKQ/2 + 4; #pragma unroll for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); @@ -954,9 +954,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); const int i0_stop = i0_start + 2*nbatch_V2; - const int i0_diff = i0_stop - i0_start; if constexpr (nstages <= 1) { + const int i0_diff = i0_stop - i0_start; if (!V_is_K_view || i0_stop > 2*nbatch_K2) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check> diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 018d5d37d47..7cfda652367 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -43,7 +43,6 @@ gated_delta_net_cuda(const float * q, // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; - const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output state += state_out_offset; curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; @@ -61,10 +60,6 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; - for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -148,6 +143,11 @@ gated_delta_net_cuda(const float * q, attn_data += S_v * H; if constexpr (keep_rs_t) { + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output const int target_slot = t - shift; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a8d54c95a..d55cc1ec7b5 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -91,7 +91,7 @@ static __global__ void mul_mat_f( const int row0 = blockIdx.x * rows_per_block; int expert_idx = 0; - int col_base = 0; + [[maybe_unused]] int col_base = 0; const int channel_dst = has_ids ? 0 : blockIdx.y; @@ -122,12 +122,12 @@ static __global__ void mul_mat_f( ids += col_offset * stride_row_id; } - const float2 * y2 = (const float2 *) y; + [[maybe_unused]] const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; char * shmem_base = data_mmv; - int * slot_map = (int *) shmem_base; + [[maybe_unused]] int * slot_map = (int *) shmem_base; char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; tile_C C[ntA][ntB]; diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 09d95f309b4..3d6de64b775 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -80,9 +80,8 @@ static __global__ void mul_mat_vec_f( gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; } - const int channel_bias = ids ? channel_x : channel_dst; - if constexpr (has_fusion) { + const int channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; } @@ -95,7 +94,7 @@ static __global__ void mul_mat_vec_f( extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; - float * buf_iw_gate = nullptr; + [[maybe_unused]] float * buf_iw_gate = nullptr; if constexpr (has_fusion) { buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); } @@ -123,7 +122,7 @@ static __global__ void mul_mat_vec_f( if constexpr (std::is_same_v<T, float>) { const float2 * x2 = (const float2 *) x; - const float2 * gate_x2 = nullptr; + [[maybe_unused]] const float2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const float2 *) gate_x; @@ -155,7 +154,7 @@ static __global__ void mul_mat_vec_f( } } else if constexpr (std::is_same_v<T, half>) { const half2 * x2 = (const half2 *) x; - const half2 * gate_x2 = nullptr; + [[maybe_unused]] const half2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const half2 *) gate_x; @@ -266,7 +265,7 @@ static __global__ void mul_mat_vec_f( } #else const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; - const nv_bfloat162 * gate_x2 = nullptr; + [[maybe_unused]] const nv_bfloat162 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const nv_bfloat162 *) gate_x; @@ -274,7 +273,7 @@ static __global__ void mul_mat_vec_f( } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const nv_bfloat162 tmpx = x2[col2]; - nv_bfloat162 tmpx_gate; + [[maybe_unused]] nv_bfloat162 tmpx_gate; if constexpr (has_fusion) { if (use_gate) { tmpx_gate = gate_x2[col2]; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index ecb6fdedadd..86b4a493019 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -515,7 +515,7 @@ static __global__ void mul_mat_vec_q( bool use_gate = false; bool use_bias = false; bool use_gate_bias = false; - const void * vgate = nullptr; + [[maybe_unused]] const void * vgate = nullptr; const float * x_bias = nullptr; const float * gate_bias = nullptr; ggml_glu_op active_glu; @@ -531,8 +531,8 @@ static __global__ void mul_mat_vec_q( } - float x_biases[ncols_dst] = { 0.0f }; - float gate_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float x_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { @@ -589,12 +589,7 @@ static __global__ void mul_mat_vec_q( } __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - if constexpr (!has_fusion) { - (void) tmp_shared_gate; - } else if (!use_gate) { - (void) tmp_shared_gate; - } + [[maybe_unused]] __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { #pragma unroll diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index da20c9aab7c..c4253bfa43b 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -134,7 +134,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * // selection_wt is only needed when bias is present (selection uses wt + bias) // when no bias, we use wt directly for both selection and weight values - float selection_wt[has_bias ? experts_per_thread : 1]; + [[maybe_unused]] float selection_wt[has_bias ? experts_per_thread : 1]; if constexpr (has_bias) { #pragma unroll From 754247f28b7615704a408ccd4c6331ab26c9d402 Mon Sep 17 00:00:00 2001 From: Todor Boinovski <todorb@qti.qualcomm.com> Date: Mon, 1 Jun 2026 23:19:07 -0700 Subject: [PATCH 234/289] hexagon: add gelu_quick (llama/24007) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 48ded82e83c..920829f6a93 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -3142,13 +3142,14 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { - case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; - case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; - case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; - case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; - case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; - case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; - case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; + case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; + case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_GELU_QUICK: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; + case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; + case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; + case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; + case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; default: break; } @@ -3630,6 +3631,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: supp = ggml_hexagon_supported_activations(sess, op); break; default: From 8d61a9edf0b3b429671f1a7037d67243d941b517 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky <maxk@qti.qualcomm.com> Date: Mon, 1 Jun 2026 23:40:08 -0700 Subject: [PATCH 235/289] hexagon: MUL_MAT, MUL_MAT_ID, FLASH_ATTN and GDN cleanup and optimizations for latest models (llama/23989) * hex-mm: initial support for F32 * F32 -> F32 matmuls * hex-rms-norm: fix src1 stride use in fused rms_norm_mul * hex-ops: clear spad pointers in the ops that clober it This fixes an odd case where fused rms-norm-mul was failing but only in qwen3.5-2B and only at searth op-bath sizes. * hmx-mm: add support for F32 * F32 -> F32 matmul_2d on HMX Decided to use Q4_0 * F32 -> F32 matmul for this. Q4_0 gets dequantized and tiled into F16, and here we quantize and tile F32 into F16. Super simple and pretty efficient. * hmx-mm: route f16 2D matmuls through the same kernel used for all other types * hmx-mm: re-introduce pipelined vs non-pipelined mode that we used to have but is much more generic way This update futher improves matmul performance and at the same time removes most of the redudant logic we had in different paths. * hmx-fa: slighlty improved pipeline simimar to matmul updates * hmx-mm: initial version of MAT_MUL_ID support for HMX * hmx-mm: fixed mxfp4 handling for MUL_MAT_ID * hex-gdn: optimize GATED_DELTA_NET DMA prefetch/double-buff, vectorize everything with HVX, in other words -- the usual :) * hmx-mm: missed one more case where we can use fastmod * hexagon: update DCVS settings for a slight perf bump * hmx-fa: use fastdiv in hmx-flash-attn * hmx-fa: precompute slope values to avoid disrupting the inner loop * hvx-utils/fa: new HVX helpers for powf and logf and using those to speed up FA alibi * hex-ops: fixed a bug in fusion logic that was messing up the order of the src tensors when some srcs are empty * hex-fa: correctly fallback to HVX if we have sinks or the dims are not quite right --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 22 + ggml/src/ggml-hexagon/htp-opnode.h | 46 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 47 +- ggml/src/ggml-hexagon/htp/argsort-ops.c | 1 + ggml/src/ggml-hexagon/htp/concat-ops.c | 2 + ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 15 +- .../ggml-hexagon/htp/gated-delta-net-ops.c | 660 ++++++++----- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 116 ++- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 885 +++++++++++++----- ggml/src/ggml-hexagon/htp/hmx-ops.c | 6 + ggml/src/ggml-hexagon/htp/hmx-ops.h | 22 +- ggml/src/ggml-hexagon/htp/htp-ctx.h | 4 + ggml/src/ggml-hexagon/htp/hvx-flash-attn.h | 47 + ggml/src/ggml-hexagon/htp/hvx-log.h | 65 ++ ggml/src/ggml-hexagon/htp/hvx-pow.h | 42 + ggml/src/ggml-hexagon/htp/hvx-utils.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 26 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 390 +++++++- ggml/src/ggml-hexagon/htp/pad-ops.c | 2 + ggml/src/ggml-hexagon/htp/unary-ops.c | 17 +- 20 files changed, 1825 insertions(+), 592 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-flash-attn.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-log.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-pow.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 920829f6a93..d550841a2a5 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1927,6 +1927,7 @@ struct ggml_hexagon_opbatch { size_t extra_tens = 0; auto fit_tensor = [&](const ggml_tensor *t) { + if (!t) return; if (!t_map.count(t)) { extra_tens++; @@ -2602,6 +2603,27 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } + break; + + case GGML_TYPE_F32: + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (src0->nb[1] < src0->nb[0]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); + return false; + } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } if (ggml_nrows(src1) > 1024) { return false; // no huge batches (for now) } diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 14b232240b4..8a1228ccdc0 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -56,7 +56,7 @@ struct htp_opnode { } std::vector<const ggml_tensor *> get_inputs() const { - std::vector<const ggml_tensor *> inputs; + std::vector<const ggml_tensor *> inputs(GGML_MAX_SRC, nullptr); std::vector<const ggml_tensor *> outputs; outputs.push_back(node); for (const auto * f : fused) { @@ -70,20 +70,38 @@ struct htp_opnode { return false; }; + int count = 0; auto add_input = [&](const ggml_tensor * t) { if (t && !contains(outputs, t) && !contains(inputs, t)) { - inputs.push_back(t); + if (count < (int)inputs.size()) { + inputs[count++] = t; + } else { + inputs.push_back(t); + } } }; - for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) { - add_input(node->src[i]); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused.empty()) { + inputs[i] = node->src[i]; + } else { + if (node->src[i]) { + add_input(node->src[i]); + } + } } for (const auto * f : fused) { - for (int i = 0; i < GGML_MAX_SRC && f->src[i]; i++) { - add_input(f->src[i]); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (f->src[i]) { + add_input(f->src[i]); + } } } + + if (!fused.empty()) { + inputs.resize(count); + } + return inputs; } @@ -108,6 +126,9 @@ struct htp_opformat { char names[64 * GGML_MAX_SRC]; int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } if (t->ne[2] == 1 && t->ne[3] == 1) { return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); } else { @@ -136,6 +157,9 @@ struct htp_opformat { } int format_tensor_strides(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } const char * c = ggml_is_contiguous(t) ? "" : "!"; if (t->ne[2] == 1 && t->ne[3] == 1) { @@ -170,11 +194,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", ggml_type_name(inputs[0]->type)); + p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(inputs[i]->type)); + p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); } p += sprintf(p, " -> "); @@ -184,7 +208,7 @@ struct htp_opformat { } const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { + if (t && t->buffer) { return ggml_backend_buffer_name(t->buffer); } return "NONE"; @@ -213,11 +237,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0]->name); + p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i]->name); + p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); } p += sprintf(p, " -> "); diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index ff3fc0804e3..f4b44fe1a65 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,27 +19,6 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c - matmul-ops.c - binary-ops.c - unary-ops.c - sum-rows-ops.c - softmax-ops.c - act-ops.c - rope-ops.c - flash-attn-ops.c - set-rows-ops.c - get-rows-ops.c - cpy-ops.c - repeat-ops.c - argsort-ops.c - ssm-conv.c - cumsum-ops.c - fill-ops.c - concat-ops.c - diag-ops.c - solve-tri-ops.c - gated-delta-net-ops.c - pad-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE @@ -58,8 +37,8 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE - hmx-flash-attn-ops.c hmx-matmul-ops.c + hmx-flash-attn-ops.c hmx-queue.c ) @@ -76,6 +55,30 @@ endif() build_idl(htp_iface.idl ${HTP_LIB}) +target_sources(${HTP_LIB} PRIVATE + matmul-ops.c + binary-ops.c + unary-ops.c + sum-rows-ops.c + softmax-ops.c + act-ops.c + rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c + cpy-ops.c + repeat-ops.c + argsort-ops.c + ssm-conv.c + cumsum-ops.c + fill-ops.c + concat-ops.c + diag-ops.c + solve-tri-ops.c + gated-delta-net-ops.c + pad-ops.c +) + set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index bdd0623615d..73af38a35ab 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -276,6 +276,7 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.size = total_spad_size; octx->src0_spad.size_per_thread = spad_per_thread; + octx->src0_spad.src = NULL; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c index 61580f2c08f..f2a381313c5 100644 --- a/ggml/src/ggml-hexagon/htp/concat-ops.c +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -262,6 +262,8 @@ int op_concat(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; if (type_size == 4) { worker_func = concat_2d_f32_transposed; diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1bd8c1407de..e996214691a 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" #include "hvx-dump.h" +#include "hvx-flash-attn.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -245,6 +246,7 @@ struct htp_fa_context { uint32_t n_head_log2; float m0; float m1; + float slopes[512]; uint32_t n_blocks; @@ -412,7 +414,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + const float slope = factx->slopes[h]; HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); @@ -628,8 +630,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } #ifdef HTP_HAS_HMX - // HMX path: head_dim multiple of 32, F16 KV - if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) { + // HMX path: head_dim multiple of 64, F16 KV, and no sinks + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); if (ret == HTP_STATUS_OK) { return ret; @@ -689,6 +691,13 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + if (n_head > 512) { + return HTP_STATUS_NO_SUPPORT; + } + for (uint32_t h = 0; h < n_head; ++h) { + factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f; + } + // total rows in q const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index c4d08bb21c4..3b092d7440d 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -3,6 +3,7 @@ #include <string.h> #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -14,106 +15,103 @@ #define HTP_GDN_MAX_SV 128 + struct htp_gdn_context { struct htp_ops_context * octx; uint32_t rows_per_thread; - size_t state_bytes; - bool use_vtcm; - uint8_t * vtcm_state_base; - size_t vtcm_state_per_thread; + size_t state_bytes; + uint8_t * vtcm_base; + size_t vtcm_per_thread; }; -static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); const HVX_Vector vmul = hvx_vec_splat_f32(mul); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vd = hvx_vmemu(dst + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, - float scale, const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + HVX_Vector vscale, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const HVX_Vector vscale = hvx_vec_splat_f32(scale); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, @@ -126,7 +124,7 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -147,11 +145,11 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -159,10 +157,10 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -185,7 +183,7 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -205,10 +203,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -216,10 +214,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -246,7 +244,7 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -267,11 +265,11 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -279,10 +277,10 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -310,7 +308,7 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -343,11 +341,11 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -359,14 +357,14 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -400,7 +398,7 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -432,10 +430,10 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -447,14 +445,14 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -496,7 +494,7 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -529,11 +527,11 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -545,14 +543,14 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -605,26 +603,65 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[4] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; const int64_t shift = (int64_t) n_tokens - (int64_t) K; + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + dma_queue_pop(dma); + dma_queue_pop(dma); - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + float * s_work_curr = s_work[curr_spad_idx]; - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - memcpy(s_out, s_in, gctx->state_bytes); - float * s_work = s_out; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -640,57 +677,117 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } @@ -698,17 +795,40 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const int64_t target_slot = (int64_t) t - shift; if (target_slot >= 0 && target_slot < (int64_t) K) { float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - if (curr_state_o != s_work) { - memcpy(curr_state_o, s_work, gctx->state_bytes); + if (curr_state_o != s_out) { + hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v); } } } attn_data += (uint64_t) S_v * H; } + + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; struct htp_ops_context * octx = gctx->octx; @@ -743,41 +863,64 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[8] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); - uint8_t * spad = NULL; - if (gctx->use_vtcm) { - spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; - } + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + dma_queue_pop(dma); + dma_queue_pop(dma); - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + float * s_work_curr = s_work[curr_spad_idx]; - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; - float * s_work; + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); - if (spad) { - dma_queue_push(dma, dma_make_ptr(spad, s_in), - S_v * sizeof(float), S_v * sizeof(float), - S_v * sizeof(float), S_v); - dma_queue_pop(dma); - s_work = (float *) spad; - } else { - s_work = s_out; - memcpy(s_work, s_in, gctx->state_bytes); - } + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); + + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -792,111 +935,145 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } - if (spad) { - dma_queue_push(dma, dma_make_ptr(s_out, spad), + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), S_v * sizeof(float), S_v * sizeof(float), S_v * sizeof(float), S_v); - dma_queue_pop(dma); + + ir_prefetch += nth; + spad_idx ^= 1; } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + int op_gated_delta_net(struct htp_ops_context * octx) { const struct htp_tensor * q = octx->src[0]; const struct htp_tensor * k = octx->src[1]; @@ -952,18 +1129,11 @@ int op_gated_delta_net(struct htp_ops_context * octx) { size_t state_aligned = (size_t) S_v * S_v * sizeof(float); state_aligned = (state_aligned + 127) & ~(size_t)127; - gctx.use_vtcm = false; - gctx.vtcm_state_base = NULL; - gctx.vtcm_state_per_thread = 0; + assert(octx->ctx->vtcm_base != NULL); + assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads); - if (n_tokens == 1 && octx->ctx->vtcm_base) { - size_t vtcm_total = state_aligned * octx->n_threads; - if (octx->ctx->vtcm_size >= vtcm_total) { - gctx.use_vtcm = true; - gctx.vtcm_state_base = octx->ctx->vtcm_base; - gctx.vtcm_state_per_thread = state_aligned; - } - } + gctx.vtcm_base = octx->ctx->vtcm_base; + gctx.vtcm_per_thread = 2 * state_aligned; if (n_tokens == 1) { worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index f132c08500d..2796564fb75 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -17,14 +17,17 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "hmx-profile.h" #include "hmx-queue.h" #include "hmx-utils.h" #include "htp-ctx.h" #include "htp-ops.h" #include "hvx-dump.h" +#include "hvx-copy.h" #include "hvx-reduce.h" #include "hvx-utils.h" +#include "hvx-flash-attn.h" #include "vtcm-utils.h" #include "worker-pool.h" @@ -46,7 +49,7 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. // Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong @@ -67,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * 1 // V tiles + + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -144,12 +147,13 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 const size_t fp16 = sizeof(__fp16); + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); // Approximate per-unit VTCM costs (without per-buffer alignment padding). const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors const size_t per_gbr2 = fp16; // D diagonal matrix const size_t per_bc = - 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + 3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs const size_t per_gbr_bc = 2 * fp16; // S + P const size_t overhead = 256 * 2 + 13 * 4096; @@ -164,7 +168,6 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. // Only relax when kv_len is too short to form enough blocks. - const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); // Cost coefficients calibrated from profiling @@ -200,7 +203,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, } // Exact VTCM verification (alignment padding may push over budget) - while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) { Bc -= bc_unit; } if (Bc < bc_unit) { @@ -303,6 +306,7 @@ struct hmx_fa_context { uint32_t n_kv_heads; // number of KV heads uint32_t n_heads; // number of Q heads uint32_t G; // GQA factor = n_heads / n_kv_heads + struct fastdiv_values div_G; uint32_t n_kv_blocks; uint32_t neq1; // Q token count @@ -321,7 +325,7 @@ struct hmx_fa_context { __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] __fp16 * vtcm_k_tiles; // K tiles (transposed) - __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_v_tiles[2]; // V tiles (column-major, double-buffered) __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] @@ -402,7 +406,9 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + + hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, (int) args->src_stride, (int) args->n_col_tiles, start, end); } @@ -464,10 +470,10 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { for (size_t r = start; r < end; r += 2) { const bool next_row_valid = (r + 1) < n_rows_g; - const size_t q_idx0 = (r + 0) / G; - const size_t h_idx0 = (r + 0) % G; - const size_t q_idx1 = (r + 1) / G; - const size_t h_idx1 = (r + 1) % G; + const size_t q_idx0 = fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = fastmodulo(r + 0, G, &factx->div_G); + const size_t q_idx1 = fastdiv(r + 1, &factx->div_G); + const size_t h_idx1 = fastmodulo(r + 1, G, &factx->div_G); const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; @@ -567,8 +573,8 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { const uint32_t ib3 = args->ib3; for (size_t r = start; r < end; ++r) { - const size_t q_idx = r / G; - const size_t h_idx = r % G; + const size_t q_idx = fastdiv(r, &factx->div_G); + const size_t h_idx = fastmodulo(r, G, &factx->div_G); // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. @@ -780,11 +786,11 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { if (args->mask_vtcm) { // Read mask from VTCM buffer (DMA'd per KV block). // GQA dedup (scheme B): skip load when qi unchanged. - const size_t qi0 = (r + 0) / G; + const size_t qi0 = fastdiv(r + 0, &factx->div_G); v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t qi1 = (r + 1) / G; + const size_t qi1 = fastdiv(r + 1, &factx->div_G); if (qi1 == qi0) { v_mask1 = v_mask0; // scheme B: reuse — same mask row } else { @@ -794,8 +800,8 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { } else { // Fallback: read mask directly from DDR (when mask->ne[2] > 1). const struct htp_tensor * mask = args->mask; - const size_t q_idx0 = args->q_start + ((r + 0) / G); - const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const size_t q_idx0 = args->q_start + fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = args->kv_head * G + fastmodulo(r + 0, G, &factx->div_G); const uint32_t im2_0 = h_idx0 % mask->ne[2]; const uint32_t im3_0 = args->ib3 % mask->ne[3]; @@ -805,12 +811,12 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t q_idx1 = args->q_start + ((r + 1) / G); + const size_t q_idx1 = args->q_start + fastdiv(r + 1, &factx->div_G); if (q_idx1 == q_idx0) { // scheme B: same mask row in DDR path v_mask1 = v_mask0; } else { - const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const size_t h_idx1 = args->kv_head * G + fastmodulo(r + 1, G, &factx->div_G); const uint32_t im2_1 = h_idx1 % mask->ne[2]; const uint32_t im3_1 = args->ib3 % mask->ne[3]; const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + @@ -1191,14 +1197,13 @@ static void hmx_fa_o_norm_worker(void * data) { // Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. // slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). // When max_bias == 0, all slopes are 1.0 (no ALiBi). -static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, +static __attribute__((noinline)) void fa_compute_slopes( const struct hmx_fa_context * factx, uint32_t kv_head, size_t n_rows_g) { + __fp16 * slopes = factx->vtcm_slopes; if (factx->max_bias == 0.0f) { - for (size_t r = 0; r < n_rows_g; ++r) { - sargs->slopes[r] = 1.0f; - } + hvx_splat_f16_a(slopes, 1.0f, n_rows_g); return; } @@ -1207,10 +1212,32 @@ static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sarg const float m0 = factx->m0; const float m1 = factx->m1; + __fp16 temp_slopes[512] __attribute__((aligned(128))); + if (G <= 32) { + // Fast path: Compute G unique slope values in vector registers + HVX_Vector v_val = hvx_alibi_slopes(kv_head, G, n_head_log2, m0, m1); + + __fp16 temp_slopes_aligned[64] __attribute__((aligned(128))); + hvx_vmem(temp_slopes_aligned) = hvx_vec_f32_to_f16(v_val, Q6_V_vzero()); + + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = temp_slopes_aligned[i]; + } + } else { + // Fallback path: G > 32 (rare configurations) + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = (__fp16)alibi_slope(kv_head * G + i, n_head_log2, m0, m1); + } + } + + // Allocate stack buffer to avoid scalar writes to VTCM (which generates L2 misses) + __fp16 local_slopes[n_rows_g] __attribute__((aligned(128))); for (size_t r = 0; r < n_rows_g; ++r) { - const uint32_t h = kv_head * G + r % G; - sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + local_slopes[r] = temp_slopes[fastmodulo(r, G, &factx->div_G)]; } + + // Copy to VTCM slopes using HVX block copy (both are aligned to 128 bytes) + hvx_copy_f16_aa((uint8_t *)slopes, (const uint8_t *)local_slopes, n_rows_g); } // ============================================================================ @@ -1254,19 +1281,22 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const uint32_t G = neq2 / n_kv_heads; // Thread count for multi-thread HVX phases - const uint32_t n_threads = octx->n_threads; + const uint32_t n_threads_init = octx->n_threads; // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) size_t Br, Bc; const size_t vtcm_budget = ctx->vtcm_size; - if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads_init) != 0) { return HTP_STATUS_VTCM_TOO_SMALL; } const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + + // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 + const uint32_t n_threads = use_pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); @@ -1282,6 +1312,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.n_kv_heads = n_kv_heads; factx.n_heads = neq2; factx.G = G; + factx.div_G = init_fastdiv_values(G); factx.neq1 = neq1; factx.Br = (uint32_t) Br; factx.Bc = (uint32_t) Bc; @@ -1354,7 +1385,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); - factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + if (use_pipeline) { + factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + } else { + factx.vtcm_v_tiles[1] = NULL; + } factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); @@ -1457,6 +1493,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ---- KV block loop with DMA double-buffering ---- size_t buf_idx = 0; + fa_compute_slopes(&factx, kv_head, n_rows_g); + // Prefetch first KV block if (factx.n_kv_blocks > 0) { const uint32_t kv_rows0 = hex_smin(Bc, nek1); @@ -1535,7 +1573,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1550,11 +1588,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); TIMER_STOP(k_interleave); - if (kv_blk > 0) { - hmx_queue_pop(hmx_q); - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- qk_job.q_tiles = factx.vtcm_q_tiles; qk_job.k_tiles = factx.vtcm_k_tiles; @@ -1574,6 +1607,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); TIMER_STOP(v_interleave); + // Pop and swap previous block's output update (deferred HMX pop) + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // Pop current block's dot product job hmx_queue_pop(hmx_q); TIMER_STOP(qk_dot); @@ -1601,7 +1641,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1617,7 +1656,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1712,7 +1751,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1732,7 +1770,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t DV_tiles = (size_t) (DV / 32); const __fp16 * restrict d_base = factx.vtcm_d_tiles; const __fp16 * restrict p_base = factx.vtcm_p_tiles; - const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles[0]; const __fp16 * restrict op_base = o_tile_prev; __fp16 * restrict oc_base = o_tile_curr; __builtin_assume(n_row_tiles > 0); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 083d125882d..dab605210cf 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -73,6 +73,10 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); default: return 0; } @@ -545,7 +549,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( int start_tile, int end_tile) { const int n_k_tiles = state->n_k_tiles; - const int qrow_size = state->k_block; + const int qrow_size = (unsigned)state->k_block / 2; const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); @@ -720,12 +724,129 @@ static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, voi } } +static void convert_f16_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } +} + + static void dequantize_x4x2_weight_chunk_to_fp16_tiles( struct htp_context *ctx, __fp16 *vtcm_dst, const void *vtcm_src, int n_cols, int k_block, size_t row_stride, int weight_type, int n_k_tiles, struct fastdiv_values n_k_tiles_div, - worker_callback_t dequant_worker_fn) { + worker_callback_t dequant_worker_fn, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); @@ -733,7 +854,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; size_t n_tot_tiles = n_col_tiles * n_k_tiles; - size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); x4x2_dequantize_state_t state; state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; @@ -748,7 +869,11 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.n_k_tiles = n_k_tiles; state.n_k_tiles_div = n_k_tiles_div; - worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); + } } // --- End x4x2 dequantizers --- @@ -876,11 +1001,11 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void } static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n) { + int n_rows, int n_cols, int n, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) output_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -891,7 +1016,11 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, state.n_cols = n_cols; state.n = n; - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); + } } // activations : fp32 -> fp16 @@ -973,12 +1102,12 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, } } -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); assert(VLEN == 32 * sizeof(float)); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address activation_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -989,7 +1118,11 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * state.k_block = k_block; state.k_stride = k_stride; - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); + } } // C += AB @@ -1031,9 +1164,9 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co } } -int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, - int weight_type) { + int act_stride, int weight_stride, int weight_type) { if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { @@ -1052,6 +1185,8 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; default: return -1; } @@ -1059,21 +1194,25 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + // --- Dynamic Mode Configuration --- + const bool use_pipeline = (m > 32); + const int num_threads = (m <= 32) ? 1 : ctx->n_threads; + // --- Dynamic VTCM layout --- const size_t vec_dot_size = k * sizeof(__fp16); const size_t vtcm_budget = ctx->vtcm_size; size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); return -1; } @@ -1083,27 +1222,27 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * size_t scratch0_size, scratch1_size, scratch2_size; scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 + scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); @@ -1114,115 +1253,137 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * TIMER_DEFINE(total); TIMER_START(total); - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] + if (use_pipeline) { + // --- Asynchronous Pipelined Loop (Current implementation) --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); + } - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); - } + { + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); + } - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); - } - } + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); + } - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } + } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + // Load Activation + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + // A: DMA Load Weight + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + + // B: Dequantize / Convert Weight + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } - hmx_queue_suspend(ctx->hmx_queue); - TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); + FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1401,11 +1562,11 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 dma_queue_pop(ctx->dma[0]); transfer_activation_chunk_threaded(ctx, vtcm_act_g, vtcm_f32_act, (int) n_rows, - params->k, params->k); + params->k, params->k, ctx->n_threads); } else { transfer_activation_chunk_threaded(ctx, vtcm_act_g, activation_chunk, (int) n_rows, - params->k, params->act_stride); + params->k, params->act_stride, ctx->n_threads); } } TIMER_STOP(activation_load); @@ -1455,7 +1616,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_START(output_store); { float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); } TIMER_STOP(output_store); } @@ -1475,177 +1636,431 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); #endif - return 0; + return 0; } -// - int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } - if (k % 32 != 0 || n % 32 != 0) { return -1; } + return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, + act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); +} - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + int ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + int start_row; + int cne1; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + int start_row; + int cne1; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + int cne1) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + int r_idx0 = start_row + r + 0; + int r_idx1 = start_row + r + 1; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } } - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } } +} - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; + transfer_activation_chunk_fp32_to_fp16_gathered( + dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); + } +} - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - return -1; +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = init_fastdiv_values(ne11), + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_activation_chunk_gathered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); } +} - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + const HVX_Vector one = hvx_vec_splat_f16(1.0); - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - TIMER_DEFINE(total); - TIMER_START(total); + int r_idx0 = start_row + (int)r + 0; + int r_idx1 = start_row + (int)r + 1; - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + if (r_idx0 >= cne1) break; - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); + volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); } } - TIMER_STOP(activation_load); + } +} - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + } +} - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } + if (k % 32 != 0 || n % 32 != 0) { return -1; } - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); + const int num_threads = ctx->n_threads; - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); + const size_t size_per_n = row_stride + vec_dot_size; + const size_t size_per_mn = sizeof(__fp16); - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); - TIMER_STOP(total); + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); + + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); + dma_queue_pop(ctx->dma[0]); + + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + transfer_output_chunk_scattered_threaded( + ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); + } } -#endif + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c new file mode 100644 index 00000000000..114d8c14811 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.c @@ -0,0 +1,6 @@ +// HMX operations compiled as a single translation unit. +// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. + +#include "hmx-queue.c" +#include "hmx-matmul-ops.c" +#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index f114edb822f..a67842f3ffc 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -52,14 +52,32 @@ int hmx_matmul_f16_f32(struct htp_context *ctx, // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_q_f32(struct htp_context *ctx, +// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, int m, int k, int n, + int act_stride, + int weight_stride, int weight_type); +struct mmid_row_mapping; + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride); + // HMX flash attention int hmx_flash_attn_ext(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 51f9243ce0a..0f1676f077a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -79,6 +79,10 @@ struct htp_context { uint64_t max_vmem; + // Persistent DDR scratchpad for MUL_MAT_ID mappings + void * ddr_spad_base; + size_t ddr_spad_size; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX diff --git a/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h new file mode 100644 index 00000000000..f1f2e49e455 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h @@ -0,0 +1,47 @@ +#ifndef HVX_FLASH_ATTN_H +#define HVX_FLASH_ATTN_H + +#include <math.h> +#include "hvx-utils.h" + +// Scalar helper to compute a single ALiBi slope. +static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) { + return (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); +} + +// Vectorized helper to compute 32 ALiBi slopes starting from (kv_head * G). +static inline HVX_Vector hvx_alibi_slopes( + uint32_t kv_head, + uint32_t G, + uint32_t n_head_log2, + float m0, + float m1 +) { + static const float ramp_32[32] __attribute__((aligned(128))) = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f + }; + HVX_Vector v_ramp = hvx_vmem(ramp_32); + HVX_Vector v_h_base = hvx_vec_splat_f32((float)(kv_head * G)); + HVX_Vector v_h = hvx_vec_add_f32_f32(v_h_base, v_ramp); + + // Compute exponent_m0: h + 1 + HVX_Vector v_exp_m0 = hvx_vec_add_f32_f32(v_h, hvx_vec_splat_f32(1.0f)); + + // Compute exponent_m1: 2 * (h - n_head_log2) + 1 + HVX_Vector v_n_head_log2 = hvx_vec_splat_f32((float)n_head_log2); + HVX_Vector v_h_minus = hvx_vec_sub_f32_f32(v_h, v_n_head_log2); + HVX_Vector v_exp_m1 = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(hvx_vec_splat_f32(2.0f), v_h_minus), hvx_vec_splat_f32(1.0f)); + + // Compute powers + HVX_Vector v_pow_m0 = hvx_vec_pow_const_base_f32(m0, v_exp_m0); + HVX_Vector v_pow_m1 = hvx_vec_pow_const_base_f32(m1, v_exp_m1); + + // Select based on h < n_head_log2 + HVX_VectorPred p_cond = Q6_Q_vcmp_gt_VsfVsf(v_n_head_log2, v_h); // v_n_head_log2 > v_h <=> h < n_head_log2 + return Q6_V_vmux_QVV(p_cond, v_pow_m0, v_pow_m1); +} + +#endif /* HVX_FLASH_ATTN_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-log.h b/ggml/src/ggml-hexagon/htp/hvx-log.h new file mode 100644 index 00000000000..7013dae785a --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-log.h @@ -0,0 +1,65 @@ +#ifndef HVX_LOG_H +#define HVX_LOG_H + +#include "hvx-base.h" + +// Approximates ln(x) element-wise for float vectors. +// x must contain positive float elements. +// Uses Abramowitz & Stegun polynomial approximation 4.1.44 for ln(1+y) over [0, 1]. +static inline HVX_Vector hvx_vec_log_f32(HVX_Vector x) { + // x = m * 2^e, where m in [1, 2) + HVX_Vector biased_e = Q6_Vuw_vlsr_VuwR(x, 23); + HVX_Vector e_int = Q6_Vw_vsub_VwVw(biased_e, Q6_V_vsplat_R(127)); + HVX_Vector e_float = Q6_Vsf_equals_Vw(e_int); + + // Extract mantissa and set exponent to 127 (which represents float value in [1.0, 2.0)) + HVX_Vector mant_mask = Q6_V_vsplat_R(0x007FFFFF); + HVX_Vector exp_127 = Q6_V_vsplat_R(0x3F800000); + HVX_Vector m = Q6_V_vor_VV(Q6_V_vand_VV(x, mant_mask), exp_127); + + // y = m - 1.0f, y in [0, 1) + HVX_Vector y = hvx_vec_sub_f32_f32(m, hvx_vec_splat_f32(1.0f)); + + // Abramowitz & Stegun 4.1.44 polynomial approximation of ln(1+y) + HVX_Vector c; + HVX_Vector res; + + c = hvx_vec_splat_f32(-0.0064535442f); + res = hvx_vec_mul_f32_f32(y, c); + + c = hvx_vec_splat_f32(0.0360884937f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.0953293897f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.1676540711f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.2407338084f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.3317990258f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.4998741238f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.9999964239f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + // ln(x) = e * ln(2) + ln(1+y) + HVX_Vector ln2 = hvx_vec_splat_f32(0.69314718056f); + HVX_Vector term_e = hvx_vec_mul_f32_f32(e_float, ln2); + + return hvx_vec_add_f32_f32(term_e, res); +} + +#endif /* HVX_LOG_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-pow.h b/ggml/src/ggml-hexagon/htp/hvx-pow.h new file mode 100644 index 00000000000..48fe0e8eade --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-pow.h @@ -0,0 +1,42 @@ +#ifndef HVX_POW_H +#define HVX_POW_H + +#include <math.h> +#include "hvx-base.h" +#include "hvx-exp.h" +#include "hvx-log.h" + +// Approximates base^exponent element-wise for float vectors. +// base must be a positive constant. exponent is an HVX f32 vector. +// Uses base^x = exp(x * ln(base)). +static inline HVX_Vector hvx_vec_pow_const_base_f32(float base, HVX_Vector exponent) { + float ln_base = logf(base); + HVX_Vector ln_base_v = hvx_vec_splat_f32(ln_base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base_v); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +// Approximates base^exponent element-wise for float vectors. +// base and exponent are HVX f32 vectors. base elements must be positive. +// Uses base^exponent = exp(exponent * ln(base)). +static inline HVX_Vector hvx_vec_pow_f32(HVX_Vector base, HVX_Vector exponent) { + HVX_Vector ln_base = hvx_vec_log_f32(base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +#endif /* HVX_POW_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 0a760cd344c..23373f73ae2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -17,5 +17,7 @@ #include "hvx-floor.h" #include "hvx-sin-cos.h" #include "hvx-base.h" +#include "hvx-pow.h" +#include "hvx-log.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 623008be4e2..3715227d2c7 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -12,6 +12,7 @@ #include <HAP_mem.h> #include <HAP_power.h> #include <HAP_ps.h> +#include <HAP_dcvs.h> #include <qurt.h> #include <qurt_thread.h> #include <qurt_memory.h> @@ -63,8 +64,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_DCVS_v3; request.dcvs_v3.set_dcvs_enable = TRUE; - request.dcvs_v3.dcvs_enable = TRUE; - request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.dcvs_enable = FALSE; request.dcvs_v3.set_bus_params = TRUE; request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; @@ -75,6 +75,10 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.set_sleep_disable = TRUE; request.dcvs_v3.sleep_disable = TRUE; + +#if (__HEXAGON_ARCH__ >= 79) + HAP_set_dcvs_v3_protected_bus_corners(&request, 1); +#endif if ((err = HAP_power_set((void *) ctx, &request)) != 0) { return err; } @@ -103,7 +107,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Setting HMX clock\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "ggml-hex: error setting HMX clock."); return err; } } @@ -117,7 +121,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Powering HMX on\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "ggml-hex: error powering on HMX."); return err; } } @@ -423,10 +427,18 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->dma[i] = dma_queue_create(256); // queue depth } + ctx->ddr_spad_size = 512 * 1024; // 512 KB + ctx->ddr_spad_base = memalign(128, ctx->ddr_spad_size); + // init worker pool err = worker_pool_init(&ctx->worker_pool, n_hvx); if (err != AEE_SUCCESS) { FARF(ERROR, "Unable to create worker pool"); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } return err; } @@ -474,6 +486,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) { vtcm_free(ctx); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } + return AEE_SUCCESS; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 7036c491bc4..5121c6f9bad 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -53,6 +53,11 @@ struct htp_matmul_context { struct fastdiv_values mm_div_ne1; struct fastdiv_values mm_div_r2; struct fastdiv_values mm_div_r3; + + // Fields for scattered mapping & HMX support in MUL_MAT_ID + const uint32_t * matrix_row_counts; + const struct mmid_row_mapping * matrix_rows; + bool hmx_eligible; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -2913,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); +} + +static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector rsum0 = Q6_V_vzero(); + HVX_Vector rsum1 = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_sf = y[i]; + HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); + HVX_VectorAlias va; + va.v = rsum; + s0[0] = va.fp32[0]; + s0[1] = va.fp32[1]; +} + +static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_sf = x0[i]; + HVX_Vector r1_sf = x1[i]; + HVX_Vector c0_sf = y0[i]; + HVX_Vector c1_sf = y1[i]; + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + + HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]); + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + HVX_VectorAlias va0, va1; + va0.v = r0_r1_c0_sum; + va1.v = r0_r1_c1_sum; + s0[0] = va0.fp32[0]; + s0[1] = va0.fp32[1]; + s1[0] = va1.fp32[0]; + s1[1] = va1.fp32[1]; +} + +static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + if (nloe) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + x_sf = Q6_V_vand_QV(bmask, x_sf); + y_sf = Q6_V_vand_QV(bmask, y_sf); + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + rsum = hvx_vec_reduce_sum_f32(rsum); + hvx_vec_store_u(&s[0], 4, rsum); +} + static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -3331,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row); + const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3466,7 +3641,7 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3516,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert - const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); - const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - - const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; - const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; + const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; @@ -3542,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { continue; } + if (mmctx->hmx_eligible) { + continue; + } + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); // Prefill spad with src0 rows @@ -3583,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3685,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -4086,6 +4262,47 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -4328,6 +4545,60 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; + } + } else if (src0->type == HTP_TYPE_F32) { + // Try optimized f32-f32 path first (src1 in VTCM) + const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); + const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); + const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + + const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; + + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); + + if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = quantize_f32_f32; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + + src1_row_size = f32_src1_row_size; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to DDR / broadcasting + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + matmul_job_func = matmul_4d; + + src1_row_size = nb11; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; } } else { @@ -4405,20 +4676,20 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul_hvx(octx); } - // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 HMX path requires K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + // F16 and F32 HMX paths require K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { return op_matmul_hvx(octx); } - if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { return op_matmul_hvx(octx); } @@ -4463,8 +4734,8 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - if (src0->type == HTP_TYPE_F16) { - if (is_batched) { + if (is_batched) { + if (src0->type == HTP_TYPE_F16) { hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, @@ -4488,13 +4759,11 @@ int op_matmul(struct htp_ops_context * octx) { }; ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_matmul_f16_f32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, - m_total, k, n, act_stride, wgt_stride); + return op_matmul_hvx(octx); } } else { - ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, (int) src0->type); + ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); } if (ret != 0) { @@ -4539,8 +4808,30 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size; + + void * mapping_buf = NULL; + bool must_free_mapping = false; + + if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) { + mapping_buf = octx->ctx->ddr_spad_base; + } else { + mapping_buf = memalign(128, total_map_size); + if (mapping_buf) { + must_free_mapping = true; + } else { + return HTP_STATUS_INTERNAL_ERR; + } + } + + uint32_t * matrix_row_counts = (uint32_t *) mapping_buf; + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size); + + mmctx->matrix_row_counts = matrix_row_counts; + mmctx->matrix_rows = matrix_rows; if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } @@ -4552,7 +4843,7 @@ int op_matmul_id(struct htp_ops_context * octx) { src1_row_size = q8x4x2_row_size(ne10); } - const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; @@ -4568,6 +4859,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -4587,9 +4879,6 @@ int op_matmul_id(struct htp_ops_context * octx) { if (src1_nrows > 1) { // initialize matrix_row_counts and map - uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; - struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; - memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); // group rows by src0 matrix @@ -4599,14 +4888,60 @@ int op_matmul_id(struct htp_ops_context * octx) { assert(i02 >= 0 && i02 < n_as); - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 }; + matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; } } } - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; + } + + bool hmx_eligible = false; +#ifdef HTP_HAS_HMX + if (octx->ctx->hmx_enabled && src1_nrows > 1) { + uint32_t wtype = src0->type; + if (ne01 % 32 == 0 && + (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { + hmx_eligible = true; + } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { + hmx_eligible = true; + } + } + } +#endif + + mmctx->hmx_eligible = hmx_eligible; + + if (hmx_eligible) { + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * ids->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + // HMX has overwritten VTCM, so force dynamic quantization cache to clear + octx->src1_spad.src = NULL; + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; + } if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); @@ -4618,5 +4953,6 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t n_matmul_jobs = octx->n_threads; worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c index 3abc3c2ead1..aaa72b31590 100644 --- a/ggml/src/ggml-hexagon/htp/pad-ops.c +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -511,6 +511,8 @@ int op_pad(struct htp_ops_context * octx) { octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.data = octx->ctx->vtcm_base; octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; } struct htp_pad_context pctx = { diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 770a6673211..71fab2cdbcb 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -692,6 +692,11 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * const uint8_t * restrict data_src1 = uctx->data_src1; uint8_t * restrict data_dst = uctx->data_dst; + const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL; + const uint32_t nb11 = src1 ? src1->nb[1] : 0; + const uint32_t nb12 = src1 ? src1->nb[2] : 0; + const uint32_t nb13 = src1 ? src1->nb[3] : 0; + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); @@ -738,10 +743,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size); } ir += block_size; @@ -823,10 +828,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad, data_src1 + src1_pref_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, pref_block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size); } } } @@ -977,6 +982,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; } + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); From d31cb20b258f335df1a658d0dc2d9a271110d14a Mon Sep 17 00:00:00 2001 From: Max Krasnyansky <maxk@qti.qualcomm.com> Date: Tue, 2 Jun 2026 14:08:29 -0700 Subject: [PATCH 236/289] hexagon: profiler output fix and script updates (llama/24042) * hex-ops: fix profiler output (ie remove the redundant NONEs) * hex-prof: update profiling script to support tot.usec column --- ggml/src/ggml-hexagon/htp-opnode.h | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 8a1228ccdc0..52c727c6206 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -56,6 +56,20 @@ struct htp_opnode { } std::vector<const ggml_tensor *> get_inputs() const { + if (fused.empty()) { + int last_non_null = -1; + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + last_non_null = i; + } + } + std::vector<const ggml_tensor *> inputs(last_non_null + 1, nullptr); + for (int i = 0; i <= last_non_null; i++) { + inputs[i] = node->src[i]; + } + return inputs; + } + std::vector<const ggml_tensor *> inputs(GGML_MAX_SRC, nullptr); std::vector<const ggml_tensor *> outputs; outputs.push_back(node); @@ -82,12 +96,8 @@ struct htp_opnode { }; for (int i = 0; i < GGML_MAX_SRC; i++) { - if (fused.empty()) { - inputs[i] = node->src[i]; - } else { - if (node->src[i]) { - add_input(node->src[i]); - } + if (node->src[i]) { + add_input(node->src[i]); } } for (const auto * f : fused) { @@ -98,10 +108,7 @@ struct htp_opnode { } } - if (!fused.empty()) { - inputs.resize(count); - } - + inputs.resize(count); return inputs; } From f110ff540c06ba74a9f26c9964eb21c1edf846b1 Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Tue, 2 Jun 2026 14:16:17 -0700 Subject: [PATCH 237/289] opencl: use flat variants of q4_K and q6_K gemv for very large M (llama/24006) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 31 +++++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b67ea46bce8..c411e4aeaec 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4950,6 +4950,21 @@ inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backen return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 } +static inline bool use_flat_gemv_for_large_m_q4_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool use_flat_gemv_for_large_m_q6_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // q6_K flat gemv is worse for smaller K; 2048 seems to be a reasonable threshold. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[0] >= 2048 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; @@ -6595,7 +6610,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; } #else @@ -6623,7 +6638,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6923,7 +6938,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS kernel = backend_ctx->kernel_convert_block_q6_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; } #else @@ -6956,7 +6971,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { cl_int M = tensor->ne[1]; // ne01 cl_int K = tensor->ne[0]; // ne00 @@ -7599,7 +7614,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -7820,7 +7835,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; static ggml_cl_buffer buf_trans_s; @@ -13213,13 +13228,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } // q4_k x fp32 - if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q4_K(src0)) { ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); return; } // q6_K x fp32 - if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q6_K(src0)) { ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); return; } From d5a49ebec8445172aa73d36c3cefb0586e7ce1ab Mon Sep 17 00:00:00 2001 From: Aman Gupta <amangupta052@gmail.com> Date: Wed, 3 Jun 2026 18:39:59 +0800 Subject: [PATCH 238/289] cuda: reserve space for quantize kv-cache at startup (llama/23907) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cuda: reserve space for quantize kv-cache at startup * address review comments * remove forward decl Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * remove assert in ggml-cuda.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --- ggml/src/ggml-cuda/fattn-common.cuh | 65 ++++++++++++++++++++++++----- ggml/src/ggml-cuda/fattn.cu | 35 ++++++++++++++++ ggml/src/ggml-cuda/fattn.cuh | 2 + ggml/src/ggml-cuda/ggml-cuda.cu | 8 ++-- 4 files changed, 96 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d650b5fbd0f..064f753f7ef 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)( typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +struct ggml_cuda_flash_attn_ext_f16_extra_data { + uintptr_t K; + uintptr_t V; + uintptr_t end; +}; + +static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data( + const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + ggml_cuda_flash_attn_ext_f16_extra_data data = {}; + data.end = (uintptr_t) dst->data + ggml_nbytes(dst); + + if (need_f16_K && K->type != GGML_TYPE_F16) { + data.end = GGML_PAD(data.end, 128); + data.K = data.end; + data.end += ggml_nelements(K)*ggml_type_size(GGML_TYPE_F16); + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + data.V = data.K; + } else { + data.end = GGML_PAD(data.end, 128); + data.V = data.end; + data.end += ggml_nelements(V)*ggml_type_size(GGML_TYPE_F16); + } + } + + return data; +} + template <int D, int nthreads> static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -952,8 +992,9 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - ggml_cuda_pool_alloc<half> K_f16(pool); - ggml_cuda_pool_alloc<half> V_f16(pool); + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V); + ggml_cuda_pool_alloc<int> KV_max(pool); ggml_cuda_pool_alloc<float> dst_tmp(pool); ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); @@ -972,10 +1013,11 @@ void launch_fattn( const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - K_f16.alloc(ggml_nelements(K)); + GGML_ASSERT(f16_extra.K != 0); + half * K_f16 = (half *) f16_extra.K; if (ggml_is_contiguously_allocated(K)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + to_fp16(K_data, K_f16, ggml_nelements(K), main_stream); nb11 = nb11*bs*sizeof(half)/ts; nb12 = nb12*bs*sizeof(half)/ts; @@ -986,13 +1028,13 @@ void launch_fattn( const int64_t s01 = nb11 / ts; const int64_t s02 = nb12 / ts; const int64_t s03 = nb13 / ts; - to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); nb11 = K->ne[0] * sizeof(half); nb12 = K->ne[1] * nb11; nb13 = K->ne[2] * nb12; } - K_data = (char *) K_f16.ptr; + K_data = (char *) K_f16; } if (need_f16_V && V->type != GGML_TYPE_F16) { @@ -1005,11 +1047,12 @@ void launch_fattn( const size_t bs = ggml_blck_size(V->type); const size_t ts = ggml_type_size(V->type); - V_f16.alloc(ggml_nelements(V)); + GGML_ASSERT(f16_extra.V != 0); + half * V_f16 = (half *) f16_extra.V; if (ggml_is_contiguously_allocated(V)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; + to_fp16(V_data, V_f16, ggml_nelements(V), main_stream); + V_data = (char *) V_f16; nb21 = nb21*bs*sizeof(half)/ts; nb22 = nb22*bs*sizeof(half)/ts; @@ -1020,13 +1063,13 @@ void launch_fattn( const int64_t s01 = nb21 / ts; const int64_t s02 = nb22 / ts; const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); nb21 = V->ne[0] * sizeof(half); nb22 = V->ne[1] * nb21; nb23 = V->ne[2] * nb22; } - V_data = (char *) V_f16.ptr; + V_data = (char *) V_f16; } } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1c7777e8a71..d6c501b1d7e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -537,6 +537,41 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_TILE; } +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst); + + bool need_f16_K = false; + bool need_f16_V = false; + + switch (kernel) { + case BEST_FATTN_KERNEL_TILE: + case BEST_FATTN_KERNEL_WMMA_F16: + case BEST_FATTN_KERNEL_MMA_F16: + need_f16_K = true; + need_f16_V = true; + break; + case BEST_FATTN_KERNEL_VEC: + need_f16_K = K->type == GGML_TYPE_F32; + need_f16_V = V->type == GGML_TYPE_F32; + break; + case BEST_FATTN_KERNEL_NONE: + break; + } + + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, need_f16_K, need_f16_V); + + return f16_extra.end - (uintptr_t) dst->data; +} + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_set_device(ctx.device); switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh index 78705d59951..f9a7e15fbd6 100644 --- a/ggml/src/ggml-cuda/fattn.cuh +++ b/ggml/src/ggml-cuda/fattn.cuh @@ -3,3 +3,5 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); + +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 18aaa098398..f5293ad4cbb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -801,7 +801,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty } static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - size_t size = ggml_nbytes(tensor); + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context; + + size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT + ? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor) + : ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; if (ggml_is_quantized(tensor->type)) { @@ -812,8 +816,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t } return size; - - GGML_UNUSED(buft); } static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { From 750fa4ca35c4b39a487724f854bfc788a2d9db53 Mon Sep 17 00:00:00 2001 From: Charles Xu <charles.xu@arm.com> Date: Wed, 3 Jun 2026 12:45:10 +0200 Subject: [PATCH 239/289] ggml-cpu: use runtime SVE width in FWHT (llama/24059) --- ggml/src/ggml-cpu/ops.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index dc73696ad9f..3a1912ae91b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8955,7 +8955,12 @@ static void ggml_compute_forward_flash_attn_ext_f16( k->type == v->type && neq1 >= Q_TILE_SZ); #ifdef GGML_SIMD - use_tiled &= (DV % GGML_F32_EPR == 0); +#if defined(__ARM_FEATURE_SVE) + const int64_t f32_epr = svcntw(); +#else + const int64_t f32_epr = GGML_F32_EPR; +#endif + use_tiled &= (DV % f32_epr == 0); #endif int current_chunk = ith; @@ -11358,7 +11363,11 @@ static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, gg // Scalar passes #if defined(GGML_SIMD) +#if defined(__ARM_FEATURE_SVE) + const int step = svcntw(); +#else const int step = GGML_F32_EPR; +#endif #else const int step = n; #endif From 00a9728de303c399f069dd1b0b7a33689ab3e56a Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Wed, 3 Jun 2026 13:56:42 +0200 Subject: [PATCH 240/289] Avoid PDL race conditions by disabling __restrict__ when PDL is used (llama/24030) * Removes __restrict__ from PDL kernel headers due to incompatibility with PDL. Adds preprocessor directives based on arch in kernel body to add __restrict__ to retain performance on older architectures. * Simplifies new __restrict__ usage via macro * Add hopper to PDL __restrict__ fix. Co-authored-by: Oliver Simons <osimons@nvidia.com> --------- Co-authored-by: Oliver Simons <osimons@nvidia.com> --- ggml/src/ggml-cuda/common.cuh | 6 ++++++ ggml/src/ggml-cuda/fattn-common.cuh | 25 ++++++++++++++++--------- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/fattn-tile.cuh | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/fattn-vec.cuh | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/getrows.cu | 5 ++++- ggml/src/ggml-cuda/mmvf.cu | 6 +++++- ggml/src/ggml-cuda/mmvq.cu | 6 +++++- ggml/src/ggml-cuda/quantize.cu | 4 +++- ggml/src/ggml-cuda/reduce_rows.cuh | 4 +++- ggml/src/ggml-cuda/set-rows.cu | 9 ++++++--- ggml/src/ggml-cuda/ssm-conv.cu | 10 +++++++--- ggml/src/ggml-cuda/ssm-scan.cu | 28 ++++++++++++++++++++++------ 14 files changed, 145 insertions(+), 62 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 560fab0b17b..e6e50e04119 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1611,6 +1611,12 @@ static bool ggml_cuda_kernel_can_use_pdl(const void * kernel) { #endif //defined(GGML_CUDA_USE_PDL) +// PDL and __restrict__ need to be mutually exclusive, see https://github.com/ggml-org/llama.cpp/pull/24030 +# if (defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER) +# define GGML_CUDA_RESTRICT +# else +# define GGML_CUDA_RESTRICT __restrict__ +# endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER template<typename Kernel, typename... Args> static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 064f753f7ef..8dfa51ad1e8 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -718,8 +718,8 @@ static __global__ void flash_attn_mask_to_KV_max( template<int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_uniform( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int ne12, const int nblocks_stream_k, const int gqa_ratio, @@ -729,6 +729,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; ggml_cuda_pdl_lc(); + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -800,8 +802,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( template <int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_general( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int gqa_ratio, const int total_work, @@ -809,6 +811,8 @@ static __global__ void flash_attn_stream_k_fixup_general( const uint3 fd_iter_k_j_z, const uint3 fd_iter_k_j, const uint3 fd_iter_k) { + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -907,11 +911,14 @@ static __global__ void flash_attn_stream_k_fixup_general( template<int D> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst, + const float * VKQ_parts_ptr, + const float2 * VKQ_meta_ptr, + float * dst_ptr, const int parallel_blocks) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT VKQ_parts = VKQ_parts_ptr; + const float2 * GGML_CUDA_RESTRICT VKQ_meta = VKQ_meta_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -1196,8 +1203,8 @@ void launch_fattn( GGML_ASSERT(block_dim.x % warp_size == 0); - // disabled PDL enrollment for now due to a compiler bug. - fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>( + ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index ac5abb13367..83478a02cb6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1703,14 +1703,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view> __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -1726,6 +1726,14 @@ static __global__ void flash_attn_ext_f16( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_sync(); // TODO optimize placement #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1871,7 +1879,7 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index fac76f13593..0a099810e14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -788,14 +788,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter( template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -810,6 +810,14 @@ static __global__ void flash_attn_tile( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: @@ -1126,7 +1134,7 @@ static __global__ void flash_attn_tile( } } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index b0a6cf67f1a..69dd9368624 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -19,14 +19,14 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) static __global__ void flash_attn_ext_vec( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -42,6 +42,14 @@ static __global__ void flash_attn_ext_vec( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -506,7 +514,7 @@ static __global__ void flash_attn_ext_vec( dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 4b6f6501094..6850716fc0d 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -24,14 +24,14 @@ namespace wmma = rocwmma; template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap> __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -46,6 +46,14 @@ static __global__ void flash_attn_ext_f16( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -494,7 +502,7 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 457b695eb2a..eb157b8baf2 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -42,7 +42,7 @@ static __global__ void k_get_rows( template<typename src0_t, typename dst_t> static __global__ void k_get_rows_float( - const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const src0_t * src0_ptr, const int32_t * src1_ptr, dst_t * dst_ptr, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, @@ -50,6 +50,9 @@ static __global__ void k_get_rows_float( const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { ggml_cuda_pdl_lc(); + const src0_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const int32_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 3d6de64b775..d7dbc8b9928 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -6,11 +6,15 @@ template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false> static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const T * x_ptr, const float * y_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride) { + const T * GGML_CUDA_RESTRICT x = x_ptr; + const float * GGML_CUDA_RESTRICT y = y_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 86b4a493019..4b0426590ac 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -476,12 +476,16 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false> __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qi = ggml_cuda_type_traits<type>::qi; diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 49516965cad..39a500a1704 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -3,10 +3,12 @@ __launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( - const float * __restrict__ x, void * __restrict__ vy, + const float * x_ptr, void * vy_ptr, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT x = x_ptr; + void * GGML_CUDA_RESTRICT vy = vy_ptr; const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 5895d3bf8e5..968c47aa20a 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -2,7 +2,9 @@ // Row reduction kernel template - compute sum (norm=false) or mean (norm=true) template <bool norm> -static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { +static __global__ void reduce_rows_f32(const float * x_ptr, float * dst_ptr, const int ncols) { + const float * GGML_CUDA_RESTRICT x = x_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; const int col = threadIdx.x; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index e14f96b824c..3b4f004c946 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -111,9 +111,9 @@ static void set_rows_cuda_quant( } template <typename src_t, typename idx_t, typename dst_t> -static __global__ void k_set_rows(const src_t * __restrict__ src0, - const idx_t * __restrict__ src1, - dst_t * __restrict__ dst, +static __global__ void k_set_rows(const src_t * src0_ptr, + const idx_t * src1_ptr, + dst_t * dst_ptr, const int64_t ne_total, const int64_t ne10, const int64_t ne11, @@ -133,6 +133,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const uint3 ne02, const uint3 ne11_fd, const uint3 ne12_fd) { + const src_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const idx_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; if (i >= ne_total) { diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 48787b4b890..1463169cf78 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,12 +3,16 @@ #include "unary.cuh" template <bool apply_silu, size_t split_d_inner, size_t d_conv> -static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, - const float * __restrict__ bias, +static __global__ void ssm_conv_f32(const float * src0_ptr, const float * src1_ptr, + const float * bias_ptr, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, - float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, + float * dst_ptr, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT bias = bias_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 412980376ac..2e3f97c7284 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -17,14 +17,22 @@ using namespace cub; #endif // __clang__ template <size_t splitD, size_t N, size_t L_template> __global__ void __launch_bounds__(splitD, 1) - ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, - const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + ssm_scan_f32(const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t d_inner, const int64_t L_param) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const size_t L = L_template == 0 ? L_param : L_template; ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); @@ -118,13 +126,21 @@ __global__ void __launch_bounds__(splitD, 1) template <int c_factor, int d_state> __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( - const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, - const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; From a1a31868870f0900940d27b5c4d426a9938731d4 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer <rehanbackup0317@gmail.com> Date: Thu, 4 Jun 2026 10:03:40 +0500 Subject: [PATCH 241/289] ggml-cpu: extend RVV quantization vec dot to higher VLENs (llama/22754) * ggml-cpu: add rvv 512b,1024b impls for iq4_xs * ggml-cpu: refactor; add rvv 512b, 1024b impls for q6_K, i-quants * ggml-cpu: refactor; add 512 and 1024 implementations of tq3_s, iq3_xxs, iq2_s, iq2_xs, iq2_xxs improve iq2_xs impl for rvv 256 Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai> --------- Co-authored-by: taimur-10x <taimur.ahmad@10xengineers.ai> Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai> --- ggml/src/ggml-cpu/arch/riscv/quants.c | 3895 +++++++++++++++++++------ 1 file changed, 2969 insertions(+), 926 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ee69e5ab5e5..47e9180bf9b 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -123,7 +123,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in assert(k % QK_K == 0); size_t nb = k / QK_K; -#if defined __riscv_v_intrinsic +#if defined __riscv_v block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); @@ -578,7 +578,8 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q2_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -590,8 +591,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; uint8_t atmp[16]; @@ -686,246 +685,281 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q2_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; uint8_t atmp[16]; - const int vector_length = __riscv_vlenb() * 8; uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "lb zero, 15(%[sc])\n\t" + "vle8.v v1, (%[sc])\n\t" + "vle8.v v2, (%[bsums])\n\t" + "addi %[tmp], %[bsums], 16\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vle8.v v3, (%[tmp])\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "lb zero, 31(%[q2])\n\t" + "addi %[tmp], %[q2], 16\n\t" + "addi %[t1], %[q8], 16\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vle8.v v0, (%[q2])\n\t" + "vle8.v v1, (%[tmp])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v3, v1, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "addi %[tmp], %[q8], 32\n\t" + "vle8.v v8, (%[q8])\n\t" + "vle8.v v9, (%[t1])\n\t" + "addi %[t1], %[t1], 32\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vsrl.vi v7, v1, 6\n\t" + "vle8.v v10, (%[tmp])\n\t" + "vle8.v v11, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v1, v1, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vle8.v v12, (%[tmp])\n\t" + "vle8.v v13, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v3, v3, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vand.vi v5, v5, 0x3\n\t" + "vle8.v v14, (%[tmp])\n\t" + "vle8.v v15, (%[t1])\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v18, v1, v9\n\t" + "vwmul.vv v20, v2, v10\n\t" + "vwmul.vv v22, v3, v11\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vwmul.vv v26, v5, v13\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v30, v7, v15\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vmv.v.x v0, zero\n\t" + "lbu %[tmp], 0(%[scale])\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "lbu %[t1], 1(%[scale])\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "lbu %[t2], 2(%[scale])\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "lbu %[t3], 3(%[scale])\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "lbu %[t4], 4(%[scale])\n\t" + "vwredsum.vs v8, v17, v8\n\t" + "vwredsum.vs v9, v19, v9\n\t" + "lbu %[t5], 5(%[scale])\n\t" + "vwredsum.vs v10, v21, v10\n\t" + "vwredsum.vs v11, v23, v11\n\t" + "lbu %[t6], 6(%[scale])\n\t" + "vwredsum.vs v12, v25, v12\n\t" + "vwredsum.vs v13, v27, v13\n\t" + "lbu %[t7], 7(%[scale])\n\t" + "vwredsum.vs v14, v29, v14\n\t" + "vwredsum.vs v15, v31, v15\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + + sumf += dall * isum; + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - size_t vl = 16; + const int nb = n / QK_K; - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + float sumf = 0; + uint8_t atmp[16]; - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - vl = 32; + size_t vl = 16; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - uint8_t is = 0; - int isum = 0; + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - for (int j = 0; j < QK_K / 128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + vl = 32; - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + uint8_t is = 0; + int isum = 0; - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); - q2 += 32; - q8 += 128; - is = 8; - } + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - sumf += dall * isum; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - uint8_t *patmp = atmp; - int vsums; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 16, e8, m1\n\t" - "vmv.v.x v8, zero\n\t" - "lb zero, 15(%[sc])\n\t" - "vle8.v v1, (%[sc])\n\t" - "vle8.v v2, (%[bsums])\n\t" - "addi %[tmp], %[bsums], 16\n\t" - "vand.vi v0, v1, 0xF\n\t" - "vsrl.vi v1, v1, 4\n\t" - "vle8.v v3, (%[tmp])\n\t" - "vse8.v v0, (%[scale])\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vzext.vf2 v0, v1\n\t" - "vwmul.vv v4, v0, v2\n\t" - "vsetivli zero, 16, e32, m4\n\t" - "vredsum.vs v8, v4, v8\n\t" - "vmv.x.s %[vsums], v8" - : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) - : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - sumf += dmin * vsums; - int isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "lb zero, 31(%[q2])\n\t" - "addi %[tmp], %[q2], 16\n\t" - "addi %[t1], %[q8], 16\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vle8.v v0, (%[q2])\n\t" - "vle8.v v1, (%[tmp])\n\t" - "vsrl.vi v2, v0, 2\n\t" - "vsrl.vi v3, v1, 2\n\t" - "vsrl.vi v4, v0, 4\n\t" - "addi %[tmp], %[q8], 32\n\t" - "vle8.v v8, (%[q8])\n\t" - "vle8.v v9, (%[t1])\n\t" - "addi %[t1], %[t1], 32\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vsrl.vi v7, v1, 6\n\t" - "vle8.v v10, (%[tmp])\n\t" - "vle8.v v11, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v1, v1, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vle8.v v12, (%[tmp])\n\t" - "vle8.v v13, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v3, v3, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vand.vi v5, v5, 0x3\n\t" - "vle8.v v14, (%[tmp])\n\t" - "vle8.v v15, (%[t1])\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v18, v1, v9\n\t" - "vwmul.vv v20, v2, v10\n\t" - "vwmul.vv v22, v3, v11\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vwmul.vv v26, v5, v13\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v30, v7, v15\n\t" - "vsetivli zero, 8, e16, m1\n\t" - "vmv.v.x v0, zero\n\t" - "lbu %[tmp], 0(%[scale])\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "lbu %[t1], 1(%[scale])\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "lbu %[t2], 2(%[scale])\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "lbu %[t3], 3(%[scale])\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "lbu %[t4], 4(%[scale])\n\t" - "vwredsum.vs v8, v17, v8\n\t" - "vwredsum.vs v9, v19, v9\n\t" - "lbu %[t5], 5(%[scale])\n\t" - "vwredsum.vs v10, v21, v10\n\t" - "vwredsum.vs v11, v23, v11\n\t" - "lbu %[t6], 6(%[scale])\n\t" - "vwredsum.vs v12, v25, v12\n\t" - "vwredsum.vs v13, v27, v13\n\t" - "lbu %[t7], 7(%[scale])\n\t" - "vwredsum.vs v14, v29, v14\n\t" - "vwredsum.vs v15, v31, v15\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [isum] "+&r" (isum) - : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q2 += 32; q8 += 128; patmp += 8; - } + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); - sumf += dall * isum; + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2 += 32; + q8 += 128; + is = 8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += dall * isum; } *s = sumf; +} +#endif +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q2_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q2_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q2_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q3_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -941,8 +975,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - uint32_t utmp[4]; float sumf = 0; @@ -1068,257 +1100,274 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q3_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; uint32_t utmp[4]; float sumf = 0; uint32_t aux[3]; - const int vector_length = __riscv_vlenb() * 8; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "lb zero, 31(%[q3])\n\t" + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "lb zero, 64(%[q8])\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "lb zero, 127(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "lb %[tmp], 0(%[scale])\n\t" + "lb %[t1], 1(%[scale])\n\t" + "lb %[t2], 2(%[scale])\n\t" + "lb %[t3], 3(%[scale])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "lb %[t4], 4(%[scale])\n\t" + "lb %[t5], 5(%[scale])\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "lb %[t6], 6(%[scale])\n\t" + "lb %[t7], 7(%[scale])\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + *s = sumf; +} - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); +void ggml_vec_dot_q3_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - size_t vl = 32; - uint8_t m = 1; + const int nb = n / QK_K; + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - int sum_t = 0; + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - for (int j = 0; j < QK_K; j += 128) { + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; - vl = 32; - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + size_t vl = 32; + uint8_t m = 1; - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; + int sum_t = 0; - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; + for (int j = 0; j < QK_K; j += 128) { - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; + vl = 32; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - vl = 16; + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - q3 += 32; q8 += 128; scale += 8; + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - } + vl = 16; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - sumf += d*sum_t; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - int8_t * scale = (int8_t *)utmp; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v0, (%[s6b])\n\t" - "vmv1r.v v2, v0\n\t" - "vsetivli zero, 2, e64, m1\n\t" - "vmv.v.x v9, %[sh]\n\t"\ - "vslidedown.vi v1, v0, 1\n\t" - "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} - "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} - "vsetivli zero, 4, e32, m1\n\t" - "vid.v v9\n\t" - "vmv.x.s %[tmp], v1\n\t" - "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} - "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} - "vsrl.vv v4, v1, v9\n\t" - "vsrl.vv v2, v0, v8\n\t" - "vand.vx v5, v4, %[kmask1]\n\t" - "vand.vx v3, v2, %[kmask2]\n\t" - "vsll.vi v6, v5, 4\n\t" - "vor.vv v7, v6, v3\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vsub.vx v0, v7, %[c]\n\t" - "vse8.v v0, (%[scale])" - : [tmp] "=&r" (tmp) - : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) - , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - uint8_t m = 1; - int isum = 0; - for (int j = 0; j < QK_K; j += 128) { - __asm__ __volatile__( - "lb zero, 31(%[q3])\n\t" - "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" - "vle8.v v8, (%[q3])\n\t" - "vsrl.vi v10, v8, 2\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v8, 6\n\t" - "lb zero, 64(%[q8])\n\t" - "vand.vi v8, v8, 3\n\t" - "vand.vi v10, v10, 3\n\t" - "vand.vi v12, v12, 3\n\t" - "vle8.v v2, (%[qh])\n\t" - "lb zero, 127(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v8, v8, -4, v0.t\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v10, v10, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v12, v12, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v14, v14, -4, v0.t\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "lb %[tmp], 0(%[scale])\n\t" - "lb %[t1], 1(%[scale])\n\t" - "lb %[t2], 2(%[scale])\n\t" - "lb %[t3], 3(%[scale])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "lb %[t4], 4(%[scale])\n\t" - "lb %[t5], 5(%[scale])\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "lb %[t6], 6(%[scale])\n\t" - "lb %[t7], 7(%[scale])\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [m] "+&r" (m), [isum] "+&r" (isum) - : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) - , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q3 += 32; q8 += 128; scale += 8; - } + q3 += 32; q8 += 128; scale += 8; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * isum; } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - *s = sumf; - -#else + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(x); - UNUSED(y); - UNUSED(nb); + sumf += d*sum_t; - ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif + } + *s = sumf; } -void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_q3_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1326,27 +1375,289 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi UNUSED(by); UNUSED(bs); - const block_q4_K * GGML_RESTRICT x = vx; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; const block_q8_K * GGML_RESTRICT y = vy; const int nb = n / QK_K; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); uint32_t utmp[4]; - -#if defined __riscv_xtheadvector - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - float sumf = 0; + uint32_t aux[3]; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(qh, vl); + + int sum_t = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x03, vl)); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf2_t qh_m0 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_0 = __riscv_vmseq_vx_u8mf2_b16(qh_m0, 0, vl); + vint8mf2_t q3_m0 = __riscv_vsub_vx_i8mf2_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m1 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_1 = __riscv_vmseq_vx_u8mf2_b16(qh_m1, 0, vl); + vint8mf2_t q3_m1 = __riscv_vsub_vx_i8mf2_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m2 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_2 = __riscv_vmseq_vx_u8mf2_b16(qh_m2, 0, vl); + vint8mf2_t q3_m2 = __riscv_vsub_vx_i8mf2_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m3 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_3 = __riscv_vmseq_vx_u8mf2_b16(qh_m3, 0, vl); + vint8mf2_t q3_m3 = __riscv_vsub_vx_i8mf2_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(q3_m0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(q3_m1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(q3_m2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(q3_m3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + } + + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf4_t vqh = __riscv_vle8_v_u8mf4(qh, vl); + + int sum_t = 0; + + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf4_t q3_x = __riscv_vle8_v_u8mf4(q3, vl); + + vint8mf4_t q3_0 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(q3_x, 0x03, vl)); + vint8mf4_t q3_1 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf4_t q3_2 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf4_t q3_3 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf4_t qh_m0 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_0 = __riscv_vmseq_vx_u8mf4_b32(qh_m0, 0, vl); + vint8mf4_t q3_m0 = __riscv_vsub_vx_i8mf4_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m1 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_1 = __riscv_vmseq_vx_u8mf4_b32(qh_m1, 0, vl); + vint8mf4_t q3_m1 = __riscv_vsub_vx_i8mf4_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m2 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_2 = __riscv_vmseq_vx_u8mf4_b32(qh_m2, 0, vl); + vint8mf4_t q3_m2 = __riscv_vsub_vx_i8mf4_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m3 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_3 = __riscv_vmseq_vx_u8mf4_b32(qh_m3, 0, vl); + vint8mf4_t q3_m3 = __riscv_vsub_vx_i8mf4_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(q3_m0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(q3_m1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(q3_m2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(q3_m3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q3_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q3_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q3_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q3_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q3_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q4_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); int tmp, tmp2, sumi; __asm__ __volatile__( @@ -1452,277 +1763,317 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + + float ftmp, ft2; + const uint8_t * restrict q40; + const uint8_t * restrict q41; + const uint8_t * restrict q42; + const uint8_t * restrict q43; + const int8_t * restrict q80; + const int8_t * restrict q81; + const int8_t * restrict q82; + const int8_t * restrict q83; + int s0, s1, s2, s3; + + __asm__ __volatile__( + "li %[s1], 8\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vle32.v v1, (%[s6b])\n\t" + "vslide1down.vx v1, v1, zero\n\t" + "vmv.v.x v16, zero\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1, ta, ma\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vsse32.v v8, (%[utmp]), %[s1]\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "addi %[s0], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v1, (%[s0]), %[s1]\n\t" + "vsetivli zero, 8, e16, m1, ta, ma\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vredsum.vs v0, v6, v16\n\t" + "vredsum.vs v0, v7, v0\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "vsetivli zero, 16, e8, m1, ta, ma\n\t" + "vle8.v v0, (%[xs])\n\t" + "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" + "addi %[q40], %[xs], 64\n\t" + "addi %[q41], %[xs], 16\n\t" + "addi %[q42], %[xs], 32\n\t" + "addi %[q43], %[xs], 48\n\t" + "addi %[q80], %[ys], 64\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "addi %[q81], %[ys], 16\n\t" + "addi %[q41], %[q41], 64\n\t" + "addi %[q82], %[ys], 32\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[ys])\n\t" + "addi %[q42], %[q42], 64\n\t" + "addi %[q83], %[ys], 48\n\t" + "addi %[q43], %[q43], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v5, v1, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q80])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vsrl.vi v6, v2, 4\n\t" + "addi %[q80], %[q80], 64\n\t" + "vle8.v v13, (%[q81])\n\t" + "vle8.v v14, (%[q82])\n\t" + "vand.vi v2, v2, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v7, v3, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vle8.v v15, (%[q83])\n\t" + "vle8.v v0, (%[q40])\n\t" + "vand.vi v3, v3, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmacc.vv v16, v1, v9\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "vwmacc.vv v24, v3, v13\n\t" + "vwmacc.vv v20, v5, v11\n\t" + "vwmacc.vv v28, v7, v15\n\t" + "addi %[q40], %[q80], 64\n\t" + "addi %[q41], %[q81], 64\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[q80])\n\t" + "addi %[q42], %[q82], 64\n\t" + "addi %[q43], %[q83], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v7, v3, 4\n\t" + "vand.vi v3, v3, 0xF\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q40])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "vsrl.vi v6, v2, 4\n\t" + "vand.vi v2, v2, 0xF\n\t" + "vwmul.vv v18, v0, v8\n\t" + "vle8.v v13, (%[q41])\n\t" + "vle8.v v14, (%[q42])\n\t" + "vwmul.vv v26, v2, v12\n\t" + "vwmul.vv v22, v4, v10\n\t" + "vwmul.vv v30, v6, v14\n\t" + "vwmacc.vv v18, v1, v9\n\t" + "vle8.v v15, (%[q43])\n\t" + "vwmacc.vv v26, v3, v13\n\t" + "vwmacc.vv v22, v5, v11\n\t" + "vwmacc.vv v30, v7, v15\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 16, e16, m2, ta, ma\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "lbu %[s0], 0(%[scale])\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "lbu %[s1], 1(%[scale])\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "lbu %[s2], 2(%[scale])\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "lbu %[s3], 3(%[scale])\n\t" + "vwredsum.vs v8, v18, v0\n\t" + "lbu %[q40], 4(%[scale])\n\t" + "vwredsum.vs v9, v22, v0\n\t" + "lbu %[q41], 5(%[scale])\n\t" + "vwredsum.vs v10, v26, v0\n\t" + "lbu %[q42], 6(%[scale])\n\t" + "vwredsum.vs v11, v30, v0\n\t" + "lbu %[q43], 7(%[scale])\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vmul.vx v0, v4, %[s0]\n\t" + "vmul.vx v1, v8, %[q40]\n\t" + "vmacc.vx v0, %[s1], v5\n\t" + "vmacc.vx v1, %[q41], v9\n\t" + "vmacc.vx v0, %[s2], v6\n\t" + "vmacc.vx v1, %[q42], v10\n\t" + "vmacc.vx v0, %[s3], v7\n\t" + "vmacc.vx v1, %[q43], v11\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfcvt.f.x.v v1, v1\n\t" + "vfmv.f.s %[ft2], v0\n\t" + "vfmv.f.s %[ftmp], v1\n\t" + "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" + "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" + : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) + , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) + , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) + , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) + : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) + , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + } - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + *s = sumf; +} - size_t vl = 8; +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + const int nb = n / QK_K; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + uint32_t utmp[4]; - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + float sumf = 0; + for (int i = 0; i < nb; ++i) { + size_t vl = 8; - vl = 32; + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - int32_t sum_1 = 0; - int32_t sum_2 = 0; + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + vl = 32; - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + int32_t sum_1 = 0; + int32_t sum_2 = 0; - q4 += 32; q8 += 64; + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - } + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - sumf += d*(sum_1 + sum_2); + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - - float ftmp, ft2; - const uint8_t * restrict q40; - const uint8_t * restrict q41; - const uint8_t * restrict q42; - const uint8_t * restrict q43; - const int8_t * restrict q80; - const int8_t * restrict q81; - const int8_t * restrict q82; - const int8_t * restrict q83; - int s0, s1, s2, s3; + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; - __asm__ __volatile__( - "li %[s1], 8\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vle32.v v1, (%[s6b])\n\t" - "vslide1down.vx v1, v1, zero\n\t" - "vmv.v.x v16, zero\n\t" - "vslidedown.vi v2, v1, 2\n\t" - "vmv1r.v v3, v2\n\t" - "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} - "vsetivli zero, 2, e32, m1, ta, ma\n\t" - "vmv.v.i v4, 4\n\t" - "vand.vx v8, v1, %[kmask1]\n\t" - "vslide1up.vx v5, v4, zero\n\t" // {0, 4} - "vsrl.vi v6, v1, 6\n\t" - "vsrl.vv v7, v2, v5\n\t" - "vsse32.v v8, (%[utmp]), %[s1]\n\t" - "vand.vx v0, v6, %[kmask3]\n\t" - "vand.vx v2, v7, %[kmask2]\n\t" - "vsll.vi v6, v0, 4\n\t" - "addi %[s0], %[utmp], 4\n\t" - "vor.vv v1, v6, v2\n\t" - "vsse32.v v1, (%[s0]), %[s1]\n\t" - "vsetivli zero, 8, e16, m1, ta, ma\n\t" - "vle32.v v2, (%[bsums])\n\t" - "vnsrl.wi v0, v2, 0\n\t" - "vnsrl.wi v1, v2, 16\n\t" - "vadd.vv v2, v0, v1\n\t" - "vle8.v v3, (%[mins])\n\t" - "vzext.vf2 v4, v3\n\t" - "vwmul.vv v6, v4, v2\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vredsum.vs v0, v6, v16\n\t" - "vredsum.vs v0, v7, v0\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "vsetivli zero, 16, e8, m1, ta, ma\n\t" - "vle8.v v0, (%[xs])\n\t" - "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" - "addi %[q40], %[xs], 64\n\t" - "addi %[q41], %[xs], 16\n\t" - "addi %[q42], %[xs], 32\n\t" - "addi %[q43], %[xs], 48\n\t" - "addi %[q80], %[ys], 64\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "addi %[q81], %[ys], 16\n\t" - "addi %[q41], %[q41], 64\n\t" - "addi %[q82], %[ys], 32\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[ys])\n\t" - "addi %[q42], %[q42], 64\n\t" - "addi %[q83], %[ys], 48\n\t" - "addi %[q43], %[q43], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v5, v1, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q80])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vsrl.vi v6, v2, 4\n\t" - "addi %[q80], %[q80], 64\n\t" - "vle8.v v13, (%[q81])\n\t" - "vle8.v v14, (%[q82])\n\t" - "vand.vi v2, v2, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v7, v3, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vle8.v v15, (%[q83])\n\t" - "vle8.v v0, (%[q40])\n\t" - "vand.vi v3, v3, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vwmul.vv v24, v2, v12\n\t" - "vwmul.vv v20, v4, v10\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmacc.vv v16, v1, v9\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "vwmacc.vv v24, v3, v13\n\t" - "vwmacc.vv v20, v5, v11\n\t" - "vwmacc.vv v28, v7, v15\n\t" - "addi %[q40], %[q80], 64\n\t" - "addi %[q41], %[q81], 64\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[q80])\n\t" - "addi %[q42], %[q82], 64\n\t" - "addi %[q43], %[q83], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v7, v3, 4\n\t" - "vand.vi v3, v3, 0xF\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q40])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "vsrl.vi v6, v2, 4\n\t" - "vand.vi v2, v2, 0xF\n\t" - "vwmul.vv v18, v0, v8\n\t" - "vle8.v v13, (%[q41])\n\t" - "vle8.v v14, (%[q42])\n\t" - "vwmul.vv v26, v2, v12\n\t" - "vwmul.vv v22, v4, v10\n\t" - "vwmul.vv v30, v6, v14\n\t" - "vwmacc.vv v18, v1, v9\n\t" - "vle8.v v15, (%[q43])\n\t" - "vwmacc.vv v26, v3, v13\n\t" - "vwmacc.vv v22, v5, v11\n\t" - "vwmacc.vv v30, v7, v15\n\t" - "vmv.v.x v0, zero\n\t" - "vsetivli zero, 16, e16, m2, ta, ma\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "lbu %[s0], 0(%[scale])\n\t" - "vwredsum.vs v5, v20, v0\n\t" - "lbu %[s1], 1(%[scale])\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "lbu %[s2], 2(%[scale])\n\t" - "vwredsum.vs v7, v28, v0\n\t" - "lbu %[s3], 3(%[scale])\n\t" - "vwredsum.vs v8, v18, v0\n\t" - "lbu %[q40], 4(%[scale])\n\t" - "vwredsum.vs v9, v22, v0\n\t" - "lbu %[q41], 5(%[scale])\n\t" - "vwredsum.vs v10, v26, v0\n\t" - "lbu %[q42], 6(%[scale])\n\t" - "vwredsum.vs v11, v30, v0\n\t" - "lbu %[q43], 7(%[scale])\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vmul.vx v0, v4, %[s0]\n\t" - "vmul.vx v1, v8, %[q40]\n\t" - "vmacc.vx v0, %[s1], v5\n\t" - "vmacc.vx v1, %[q41], v9\n\t" - "vmacc.vx v0, %[s2], v6\n\t" - "vmacc.vx v1, %[q42], v10\n\t" - "vmacc.vx v0, %[s3], v7\n\t" - "vmacc.vx v1, %[q43], v11\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfcvt.f.x.v v1, v1\n\t" - "vfmv.f.s %[ft2], v0\n\t" - "vfmv.f.s %[ftmp], v1\n\t" - "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" - "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" - : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) - , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) - , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) - , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) - : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) - , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) - , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) - , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += d*(sum_1 + sum_2); + } *s = sumf; +} +#endif +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q4_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q4_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 256 and above + ggml_vec_dot_q4_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(nb); - UNUSED(utmp); - ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -1823,7 +2174,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; - } sums += aux32 * d; @@ -1846,7 +2196,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q6_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1859,8 +2210,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -1939,224 +2288,462 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(&x[i + 1].d, 0, 1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int q6h; + float ftmp; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "addi %[q6h], %[q6], 32\n\t" + "ld t0, 0(%[scale])\n\t" + "addi %[scale], %[scale], 8\n\t" + "slli t6, t0, 1 * 8\n\t" + "lb zero, 0(%[q6])\n\t" + "slli t5, t0, 2 * 8\n\t" + "slli t4, t0, 3 * 8\n\t" + "lb zero, 0(%[q6h])\n\t" + "slli t3, t0, 4 * 8\n\t" + "slli t2, t0, 5 * 8\n\t" + "lb zero, 0(%[qh])\n\t" + "lb zero, 31(%[q6h])\n\t" + "slli t1, t0, 6 * 8\n\t" + "srai a7, t0, 56\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v8, (%[q6])\n\t" + "srai t6, t6, 56\n\t" + "srai t5, t5, 56\n\t" + "srai t4, t4, 56\n\t" + "srai t3, t3, 56\n\t" + "vle8.v v10, (%[q6h])\n\t" + "addi %[q6], %[q6], 64\n\t" + "slli t0, t0, 7 * 8\n\t" + "srai t2, t2, 56\n\t" + "srai t1, t1, 56\n\t" + "srai t0, t0, 56\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v10, 4\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vand.vi v10, v10, 0xF\n\t" + "lb zero, 32(%[q8])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "lb zero, 64(%[q8])\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "lb zero, 96(%[q8])\n\t" + "vand.vx v2, v2, %[mask]\n\t" + "vand.vx v4, v4, %[mask]\n\t" + "vand.vx v6, v6, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "lb zero, 127(%[q8])\n\t" + "vor.vv v10, v10, v2\n\t" + "vor.vv v12, v12, v4\n\t" + "vor.vv v14, v14, v6\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v10, t0\n\t" + "vmul.vx v1, v9, t1\n\t" + "vmacc.vx v0, t2, v8\n\t" + "vmacc.vx v1, t3, v7\n\t" + "vmacc.vx v0, t4, v11\n\t" + "vmacc.vx v1, t5, v12\n\t" + "vmacc.vx v0, t6, v13\n\t" + "vmacc.vx v1, a7, v14\n\t" + "vadd.vv v0, v0, v1\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" + : [q6] "+&r" (q6), [q6h] "=&r" (q6h) + , [scale] "+&r" (scale) + , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) + : [qh] "r" (qh), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30), [d] "f" (d) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" + , "a6", "a5", "a4", "a3" + ); + qh += 32; q8 += 128; + } + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * GGML_RESTRICT scale = x[i].scales; - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + size_t vl; - const int8_t * GGML_RESTRICT scale = x[i].scales; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - size_t vl; + int sum_t = 0; + int is = 0; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - int sum_t = 0; - int is = 0; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+32, vl); + + vuint8mf2_t q6a_0 = __riscv_vand_vx_u8mf2(q6_0, 0x0F, vl); + vuint8mf2_t q6a_1 = __riscv_vand_vx_u8mf2(q6_1, 0x0F, vl); + vuint8mf2_t q6s_0 = __riscv_vsrl_vx_u8mf2(q6_0, 0x04, vl); + vuint8mf2_t q6s_1 = __riscv_vsrl_vx_u8mf2(q6_1, 0x04, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(qh_x, 0x03, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf2_t qhi_0 = __riscv_vor_vv_u8mf2(q6a_0, __riscv_vsll_vx_u8mf2(qh_0, 0x04, vl), vl); + vuint8mf2_t qhi_1 = __riscv_vor_vv_u8mf2(q6a_1, __riscv_vsll_vx_u8mf2(qh_1, 0x04, vl), vl); + vuint8mf2_t qhi_2 = __riscv_vor_vv_u8mf2(q6s_0, __riscv_vsll_vx_u8mf2(qh_2, 0x04, vl), vl); + vuint8mf2_t qhi_3 = __riscv_vor_vv_u8mf2(q6s_1, __riscv_vsll_vx_u8mf2(qh_3, 0x04, vl), vl); + + vint8mf2_t a_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_0), 32, vl); + vint8mf2_t a_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_1), 32, vl); + vint8mf2_t a_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_2), 32, vl); + vint8mf2_t a_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_3), 32, vl); + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(a_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(a_1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(a_2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(a_3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); + + q6 += 64; qh += 32; q8 += 128; is=8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + sumf += d * sum_t; + + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - for (int j = 0; j < QK_K/128; ++j) { + const int nb = n / QK_K; - vl = 32; + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + float sumf = 0; - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + const int8_t * GGML_RESTRICT scale = x[i].scales; - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + size_t vl = 32; - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + int sum_t = 0; + int is = 0; - vl = 16; + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf4_t qh_x = __riscv_vle8_v_u8mf4(qh, vl); + + // load Q6 + vuint8mf4_t q6_0 = __riscv_vle8_v_u8mf4(q6, vl); + vuint8mf4_t q6_1 = __riscv_vle8_v_u8mf4(q6+32, vl); + + vuint8mf4_t q6a_0 = __riscv_vand_vx_u8mf4(q6_0, 0x0F, vl); + vuint8mf4_t q6a_1 = __riscv_vand_vx_u8mf4(q6_1, 0x0F, vl); + vuint8mf4_t q6s_0 = __riscv_vsrl_vx_u8mf4(q6_0, 0x04, vl); + vuint8mf4_t q6s_1 = __riscv_vsrl_vx_u8mf4(q6_1, 0x04, vl); + + vuint8mf4_t qh_0 = __riscv_vand_vx_u8mf4(qh_x, 0x03, vl); + vuint8mf4_t qh_1 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf4_t qh_2 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf4_t qh_3 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf4_t qhi_0 = __riscv_vor_vv_u8mf4(q6a_0, __riscv_vsll_vx_u8mf4(qh_0, 0x04, vl), vl); + vuint8mf4_t qhi_1 = __riscv_vor_vv_u8mf4(q6a_1, __riscv_vsll_vx_u8mf4(qh_1, 0x04, vl), vl); + vuint8mf4_t qhi_2 = __riscv_vor_vv_u8mf4(q6s_0, __riscv_vsll_vx_u8mf4(qh_2, 0x04, vl), vl); + vuint8mf4_t qhi_3 = __riscv_vor_vv_u8mf4(q6s_1, __riscv_vsll_vx_u8mf4(qh_3, 0x04, vl), vl); + + vint8mf4_t a_0 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_0), 32, vl); + vint8mf4_t a_1 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_1), 32, vl); + vint8mf4_t a_2 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_2), 32, vl); + vint8mf4_t a_3 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_3), 32, vl); + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(a_0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(a_1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(a_2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(a_3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + q6 += 64; qh += 32; q8 += 128; is=8; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + } - q6 += 64; qh += 32; q8 += 128; is=8; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); - } + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); - sumf += d * sum_t; + sumf += d * sum_t; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - - __builtin_prefetch(&x[i + 1].d, 0, 1); - - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int q6h; - float ftmp; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "addi %[q6h], %[q6], 32\n\t" - "ld t0, 0(%[scale])\n\t" - "addi %[scale], %[scale], 8\n\t" - "slli t6, t0, 1 * 8\n\t" - "lb zero, 0(%[q6])\n\t" - "slli t5, t0, 2 * 8\n\t" - "slli t4, t0, 3 * 8\n\t" - "lb zero, 0(%[q6h])\n\t" - "slli t3, t0, 4 * 8\n\t" - "slli t2, t0, 5 * 8\n\t" - "lb zero, 0(%[qh])\n\t" - "lb zero, 31(%[q6h])\n\t" - "slli t1, t0, 6 * 8\n\t" - "srai a7, t0, 56\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vle8.v v8, (%[q6])\n\t" - "srai t6, t6, 56\n\t" - "srai t5, t5, 56\n\t" - "srai t4, t4, 56\n\t" - "srai t3, t3, 56\n\t" - "vle8.v v10, (%[q6h])\n\t" - "addi %[q6], %[q6], 64\n\t" - "slli t0, t0, 7 * 8\n\t" - "srai t2, t2, 56\n\t" - "srai t1, t1, 56\n\t" - "srai t0, t0, 56\n\t" - "vle8.v v4, (%[qh])\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v10, 4\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vi v8, v8, 0xF\n\t" - "vand.vi v10, v10, 0xF\n\t" - "lb zero, 32(%[q8])\n\t" - "vsll.vi v0, v4, 4\n\t" - "vsll.vi v2, v4, 2\n\t" - "lb zero, 64(%[q8])\n\t" - "vsrl.vi v6, v4, 2\n\t" - "vand.vx v0, v0, %[mask]\n\t" - "lb zero, 96(%[q8])\n\t" - "vand.vx v2, v2, %[mask]\n\t" - "vand.vx v4, v4, %[mask]\n\t" - "vand.vx v6, v6, %[mask]\n\t" - "vor.vv v8, v8, v0\n\t" - "lb zero, 127(%[q8])\n\t" - "vor.vv v10, v10, v2\n\t" - "vor.vv v12, v12, v4\n\t" - "vor.vv v14, v14, v6\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "vsub.vx v8, v8, %[vl32]\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v10, t0\n\t" - "vmul.vx v1, v9, t1\n\t" - "vmacc.vx v0, t2, v8\n\t" - "vmacc.vx v1, t3, v7\n\t" - "vmacc.vx v0, t4, v11\n\t" - "vmacc.vx v1, t5, v12\n\t" - "vmacc.vx v0, t6, v13\n\t" - "vmacc.vx v1, a7, v14\n\t" - "vadd.vv v0, v0, v1\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" - : [q6] "+&r" (q6), [q6h] "=&r" (q6h) - , [scale] "+&r" (scale) - , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) - : [qh] "r" (qh), [q8] "r" (q8) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - , [mask] "r" (0x30), [d] "f" (d) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" - , "a6", "a5", "a4", "a3" - ); - qh += 32; q8 += 128; - } - } - break; - default: - assert(false && "Unsupported vector length"); - break; } *s = sumf; +} +#endif +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q6_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q6_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q6_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q6_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q6_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2364,10 +2951,190 @@ static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16m1_t index = __riscv_vsll_vx_u16m1(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, index, 32)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 256); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 256); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 4), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 5), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 6), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 7), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // Mask for processing 32 elements per lsum register. + vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16mf2_t index = __riscv_vlmul_trunc_v_u16m1_u16mf2(__riscv_vsll_vx_u16m1(qh_index, 3, 32)); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m2_t grid0 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vluxei16_v_i64m2((const int64_t*)iq1s_grid, index, 32)); + vint8m2_t q80 = __riscv_vle8_v_i8m2(y[i].qs, 256); + vint16m4_t lsum0 = __riscv_vwmul_vv_i16m4(grid0, q80, 256); + + // Reduce. + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 64)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 64)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 64)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 64)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2375,6 +3142,12 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2384,7 +3157,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2664,10 +3437,287 @@ static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(32); + const vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 32); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 32); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 16); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 16); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 16); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 32), 32); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 16)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16m1(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 16)), 32), 0, 32); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 32); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 256); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_1, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_1, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_2, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_3, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_2, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_3, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_4, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_5, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_4, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_5, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_6, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_7, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_6, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_7, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_8, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_9, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_8, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_9, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_10, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_11, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_10, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_11, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_12, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_13, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_12, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_13, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_14, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_15, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_14, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_15, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 32); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 32)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 32)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 64); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 64); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf8_t qh_8 = __riscv_vle8_v_u8mf8(qh, 16); + const vuint16mf4_t qh_16_lo = __riscv_vzext_vf2_u16mf4(qh_8, 16); + const vuint16mf4_t qh_16_hi = __riscv_vsll_vx_u16mf4(qh_16_lo, 8, 16); + const vuint16mf2_t qhb = __riscv_vzext_vf2_u16mf2( + __riscv_vreinterpret_v_u16mf4_u8mf4(__riscv_vor_vv_u16mf4(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16mf2_t qsb = __riscv_vzext_vf2_u16mf2(__riscv_vle8_v_u8mf4(&qs[0], 32), 32); + const vuint16mf2_t shift = __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00040008, 16)); + vuint16mf2_t index = __riscv_vor_vv_u16mf2(qsb, __riscv_vand_vx_u16mf2(__riscv_vsll_vv_u16mf2(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16mf2(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m2_t iq1b = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vreinterpret_v_u64m2_i64m2( + __riscv_vluxei16_v_u64m2(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool32_t mask = __riscv_vmsgtu_vx_u16mf2_b32( + __riscv_vand_vv_u16mf2(qhb, __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00800008, 16)), 32), 0, 32); + const vint64m2_t delta_pos = __riscv_vmv_v_x_i64m2(0x0101010101010101, 32); + const vint8m2_t delta = __riscv_vreinterpret_v_i64m2_i8m2( + __riscv_vmerge_vxm_i64m2(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(iq1b, q8b, 256); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + + // Accumulate in acc1 and acc2 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_4, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_4, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_8, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_8, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_12, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_12, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 16); + // + const vbool16_t l_mask_16_32 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_1, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_1, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_5, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_5, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_9, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_9, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_13, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_13, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 32); + // + const vbool16_t l_mask_32_48 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_2, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_2, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_6, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_6, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_10, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_10, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_14, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_14, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 48); + // + const vbool16_t l_mask_48_64 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 47, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_3, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_3, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_7, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_7, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_11, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_11, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_15, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_15, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 64); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 64)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 64)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2675,6 +3725,12 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_m_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_m_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2684,7 +3740,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const uint8_t sign_gather_indices_arr[64] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 @@ -2887,10 +3943,275 @@ static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = 0.125f * sumf; } + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + vuint8m2_t v_ids = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 128); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 128); + + uint16_t gather_qh_arr[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 16); + + uint16_t shift_qh_arr[16] = {11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 16); + + // Masks for selecting lower/upper 16 lanes within a 32-lane i16m1 register + vuint16m1_t v_ids16 = __riscv_vid_v_u16m1(32); + vbool16_t m_hi16 = __riscv_vmsgeu_vx_u16m1_b16(v_ids16, 16, 32); + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 2; ++ib) { + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 16); + qs += 16; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 4); + qh += 4; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 4); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 16); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 16); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 16); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 16); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 16); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 16); + signs_ptr += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 128); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m1_t v0 = __riscv_vget_v_i16m4_i16m1(v_dot, 0); + vint16m1_t v1 = __riscv_vget_v_i16m4_i16m1(v_dot, 1); + vint16m1_t v2 = __riscv_vget_v_i16m4_i16m1(v_dot, 2); + vint16m1_t v3 = __riscv_vget_v_i16m4_i16m1(v_dot, 3); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v0, v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v0, v_zero, 32)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v1, v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v1, v_zero, 32)); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v2, v_zero, 16)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v2, v_zero, 32)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( v3, v_zero, 16)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v3, v_zero, 32)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + sum_block += s4 * (2 * (sc2 & 0xF) + 1); + sum_block += s5 * (2 * (sc2 >> 4) + 1); + sum_block += s6 * (2 * (sc3 & 0xF) + 1); + sum_block += s7 * (2 * (sc3 >> 4) + 1); + } + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + vuint8m2_t v_ids = __riscv_vid_v_u8m2(256); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 256); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 256); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 256); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 256); + + uint16_t gather_qh_arr[32] = { + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7 + }; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 32); + + uint16_t shift_qh_arr[32] = { + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5 + }; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 32); + + // Masks for 4 groups of 16 lanes within a 64-lane i16m4 chunk + vuint16m4_t v_ids64 = __riscv_vid_v_u16m4(64); + vbool4_t m_g0 = __riscv_vmsltu_vx_u16m4_b4(v_ids64, 16, 64); + vbool4_t m_g1 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 16, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 32, 64), 64); + vbool4_t m_g2 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 32, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 48, 64), 64); + vbool4_t m_g3 = __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 48, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 32); + qs += 32; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 8); + qh += 8; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 8); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 32); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 32); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 32); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 32); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 32); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + //loading signs + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs_ptr, 32); + signs_ptr += 32; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 256); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 256); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 256); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + q8 += 256; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 256); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 256); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint16m4_t c = v_dot; + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s8 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s9 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s10 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s11 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s12 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s13 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s14 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s15 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + int32_t sums_arr[16] = { s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15 }; + + // Load 8 scale bytes and split into 16 nibbles + vuint8mf2_t v_sc8 = __riscv_vle8_v_u8mf2(scales, 8); + scales += 8; + + vuint8mf2_t v_lo8 = __riscv_vand_vx_u8mf2(v_sc8, 0x0F, 8); + vuint8mf2_t v_hi8 = __riscv_vsrl_vx_u8mf2(v_sc8, 4, 8); + + vuint8m1_t v_idx16 = __riscv_vid_v_u8m1(16); + vuint8m1_t v_half = __riscv_vsrl_vx_u8m1(v_idx16, 1, 16); + vbool8_t m_even = __riscv_vmseq_vx_u8m1_b8(__riscv_vand_vx_u8m1(v_idx16, 1, 16), 0, 16); + + vuint8m1_t v_lo_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_lo8); + vuint8m1_t v_hi_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_hi8); + vuint8m1_t v_lo_g = __riscv_vrgather_vv_u8m1(v_lo_ext, v_half, 16); + vuint8m1_t v_hi_g = __riscv_vrgather_vv_u8m1(v_hi_ext, v_half, 16); + vuint8m1_t v_nib = __riscv_vmerge_vvm_u8m1(v_lo_g, v_hi_g, m_even, 16); + + static const uint8_t iq2s_scale_lut_16_local[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + }; + vuint8m1_t v_lut = __riscv_vle8_v_u8m1(iq2s_scale_lut_16_local, 16); + vuint8m1_t v_sc8v = __riscv_vrgather_vv_u8m1(v_lut, v_nib, 16); + + vint32m4_t v_sums = __riscv_vle32_v_i32m4(sums_arr, 16); + vuint16m2_t v_sc16 = __riscv_vwcvtu_x_x_v_u16m2(v_sc8v, 16); + vuint32m4_t v_sc32u = __riscv_vwcvtu_x_x_v_u32m4(v_sc16, 16); + vint32m4_t v_sc32 = __riscv_vreinterpret_v_u32m4_i32m4(v_sc32u); + vint32m4_t v_prod = __riscv_vmul_vv_i32m4(v_sums, v_sc32, 16); + + vint32m1_t v_zero32 = __riscv_vmv_v_x_i32m1(0, 1); + int32_t sum_part = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(v_prod, v_zero32, 16)); + sum_block += sum_part; + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} #endif void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2898,8 +4219,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; default: - ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -2907,7 +4231,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -3045,59 +4369,140 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT int32_t sum_int = 0; - // Loop over 4 subblocks of 64 elements (QK_K = 256) - for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { - // Load 8 uint16 indices (controls 64 values) - vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8); - qs += 8; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16); + qs += 16; - // Extract indices for grid (low 9 bits) and signs (high 7 bits) - // Multiply by 8 (<< 3) for byte offsets into the uint64 tables - vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8); - vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8); + // Prepare offsets for grid and signs + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16); - vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8); - vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8); + // Indexed load 128 weights (16 x 8-byte chunks) + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16); - vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64)); - vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64)); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); - vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64); + // Apply signs to get dequantized IQ2 values + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128); + asm volatile("" ::: "memory"); - vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; + // Load corresponding Q8 weights + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128); + asm volatile("" ::: "memory"); - vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64); + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); - int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16)); - int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16)); - int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16)); - int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16)); + // 9. Reduce each 16-element chunk and apply corresponding nibble scale - const uint8_t scale_byte_1 = scales[0]; - const uint8_t scale_byte_2 = scales[1]; - scales += 2; + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16)); + sum_int += s0 * ((sc0 & 0x0F) * 2 + 1); - sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); - sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); - sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); - sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16)); + sum_int += s1 * ((sc0 >> 4) * 2 + 1); + + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16)); + sum_int += s2 * ((sc1 & 0x0F) * 2 + 1); + + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16)); + sum_int += s3 * ((sc1 >> 4) * 2 + 1); + + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16)); + sum_int += s4 * ((sc2 & 0x0F) * 2 + 1); + + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16)); + sum_int += s5 * ((sc2 >> 4) * 2 + 1); + + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16)); + sum_int += s6 * ((sc3 & 0x0F) * 2 + 1); + + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16)); + sum_int += s7 * ((sc3 >> 4) * 2 + 1); } - sumf += d * sum_int; + sumf += d * (float)sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // Load indices --- + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 32); + + // Extract low 9 bits and multiply by 8 (shift left 3) for byte offset into uint64 table + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 32), 3, 32); + + // Extract high 7 bits (shift right 9) and multiply by 8 (shift left 3) for byte offset + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 32); + + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q2_signed = __riscv_vmul_vv_i8m4(q2_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q2_signed, q8_all, 256); + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + +#pragma GCC unroll 1 + for (int j = 0; j < 8; ++j) { + uint8_t sc = scales[j]; + int16_t sc_lo = 2 * (sc & 0x0F) + 1; + int16_t sc_hi = 2 * (sc >> 4) + 1; + + vint32m1_t sum_v0 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32, 16), zero_vec, 16); + int32_t isum0 = __riscv_vmv_x_s_i32m1_i32(sum_v0); + + vint32m1_t sum_v1 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32 + 16, 16), zero_vec, 16); + int32_t isum1 = __riscv_vmv_x_s_i32m1_i32(sum_v1); + + sum += (float)isum0 * sc_lo + (float)isum1 * sc_hi; + } + + sumf += sum * combined_scale; } *s = 0.125f * sumf; } #endif void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3105,8 +4510,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3114,7 +4519,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3299,24 +4704,99 @@ static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.125f * sumf; } + +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + // Shift pattern {0,7,14,21} repeated 8 times for all 8 sub-blocks + uint8_t shift_arr[32] = { + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21 + }; + vuint8mf2_t v_shifts = __riscv_vle8_v_u8mf2(shift_arr, 32); + + // Gather pattern to broadcast the 8 sub-block scales across the 32 lookup slots + uint8_t gather_arr[32] = { + 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3, + 4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7 + }; + vuint8mf2_t v_sign_gather_idx = __riscv_vle8_v_u8mf2(gather_arr, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // De-interleave all 8 Index/Scale pairs for the 8x32-element sub-blocks + vuint32mf2x2_t tuple = __riscv_vlseg2e32_v_u32mf2x2((const uint32_t*)q2_ptr, 8); + vuint32mf2_t v_ind32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 0); + vuint32mf2_t v_sc32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 1); + + vuint8mf2_t v_raw_q2 = __riscv_vreinterpret_v_u32mf2_u8mf2(v_ind32); + vuint16m1_t vidx_q2 = __riscv_vwcvtu_x_x_v_u16m1(v_raw_q2, 32); + vidx_q2 = __riscv_vsll_vx_u16m1(vidx_q2, 3, 32); + + vuint32m2_t v_s = __riscv_vrgatherei16_vv_u32m2(__riscv_vlmul_ext_v_u32mf2_u32m2(v_sc32), __riscv_vwcvtu_x_x_v_u16m1(v_sign_gather_idx,32), 32); + v_s = __riscv_vsrl_vv_u32m2(v_s, __riscv_vwcvtu_x_x_v_u32m2(__riscv_vwcvtu_x_x_v_u16m1(v_shifts,32),32), 32); + v_s = __riscv_vand_vx_u32m2(v_s, 127, 32); + vuint16m1_t vidx_s2 = __riscv_vsll_vx_u16m1(__riscv_vncvt_x_x_w_u16m1(v_s, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_q2, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_s2, 32); + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q8s_all = __riscv_vmul_vv_i8m4(q8_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q8s_all, q2_all, 256); + + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + for (int j = 0; j < 8; ++j) { + uint32_t s_p = __riscv_vmv_x_s_u32mf2_u32(__riscv_vslidedown_vx_u32mf2(v_sc32, j, 8)); + int16_t sc = 2 * ((s_p >> 28) & 0xF) + 1; + dot_all=__riscv_vslidedown_vx_i16m8(dot_all,j*32,32); + vint32m1_t sum_v = __riscv_vwredsum_vs_i16m8_i32m1(dot_all, zero_vec, 32); + int32_t isum = __riscv_vmv_x_s_i32m1_i32(sum_v); + sum += (float)isum * sc; + } + + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} #endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: // 256 and above + case 256: ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + default: // 512 and above + ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; } #else ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3506,19 +4986,108 @@ static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + // Generate Constants + vuint8mf2_t v_id_32 = __riscv_vid_v_u8mf2(32); + vuint8mf2_t v_qh_gather = __riscv_vsrl_vx_u8mf2(v_id_32, 3, 32); + vuint8mf2_t v_qh_shifts = __riscv_vand_vx_u8mf2(v_id_32, 7, 32); + vuint8m2_t v_id_128 = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather = __riscv_vsrl_vx_u8m2(v_id_128, 3, 128); // byte index + vuint8m2_t v_sign_shift_amts = __riscv_vand_vx_u8m2(v_id_128, 7, 128); // bit shift + vuint8m2_t v_one_128 = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_one_128, v_sign_shift_amts, 128); + vuint8m2_t v_scale_indices = __riscv_vsrl_vx_u8m2(v_id_128, 5, 128); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + for (int ib = 0; ib < 2; ++ib) { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 32); + qs += 32; + vuint8mf2_t v_qh_loaded = __riscv_vle8_v_u8mf2(qh, 4); + qh += 4; + vuint8mf2_t v_qh_expanded = __riscv_vrgather_vv_u8mf2(v_qh_loaded, v_qh_gather, 32); + v_qh_expanded = __riscv_vsrl_vv_u8mf2(v_qh_expanded, v_qh_shifts, 32); + v_qh_expanded = __riscv_vand_vx_u8mf2(v_qh_expanded, 1, 32); + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 32); // * 4 + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_expanded, 32); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 32); // * 256 * 4 + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 32); + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs, 16); + signs += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 128); + uint16_t sc_raw; + memcpy(&sc_raw, scales, 2); + scales += 2; // Advance 2 bytes + + uint8_t sc_unpacked[4]; + sc_unpacked[0] = (sc_raw & 0xF); + sc_unpacked[1] = (sc_raw >> 4) & 0xF; + sc_unpacked[2] = (sc_raw >> 8) & 0xF; + sc_unpacked[3] = (sc_raw >> 12) & 0xF; + + vuint8mf2_t v_sc_4 = __riscv_vle8_v_u8mf2(sc_unpacked, 4); + v_sc_4 = __riscv_vmul_vx_u8mf2(v_sc_4, 2, 4); + v_sc_4 = __riscv_vadd_vx_u8mf2(v_sc_4, 1, 4); + vuint8m2_t v_sc_4_expanded = __riscv_vlmul_ext_v_u8mf2_u8m2(v_sc_4); + vuint8m2_t v_scales_bcast = __riscv_vrgather_vv_u8m2(v_sc_4_expanded, v_scale_indices, 128); + vint16m4_t v_scales_i16 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwcvtu_x_x_v_u16m4(v_scales_bcast, 128)); + vint32m8_t v_weighted_sum = __riscv_vwmul_vv_i32m8(v_dot, v_scales_i16, 128); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + int32_t s_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m8_i32m1(v_weighted_sum, v_zero, 128)); + + sum_block += s_val; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} #endif void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: - ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; case 256: ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3526,7 +5095,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3712,10 +5281,181 @@ static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.25f * sumf; } + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // generate constants for unpacking metadata words into sign indices + vuint32m1_t v_shifts; + { + vuint32m1_t v_base = __riscv_vid_v_u32m1(16); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_base, 3, 16); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 16); + } + + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_idx = __riscv_vid_v_u16mf2(16); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_idx, 2, 16); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 32); + q3_indices += 32; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 32); + vuint32m2_t v_q3_mag_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 32); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_mag_u32)); + vuint32m1_t v_aux = __riscv_vreinterpret_v_u8m1_u32m1(__riscv_vle8_v_u8m1(metadata, 16)); + metadata += 4 * sizeof(uint32_t); + + vuint32m1_t v_aux_expanded = __riscv_vrgatherei16_vv_u32m1(v_aux, v_gather_idx, 16); + + vuint32m1_t v_s_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 16), 127, 16); + vuint16mf2_t sign_byte_offset = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_s_raw, 16), 3, 16); + vuint64m2_t v_s_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_byte_offset, 16); + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_s_u64)); + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_signs, 128); + vint16m4_t prod = __riscv_vwmul_vv_i16m4(v_q3_signed, v_q8, 128); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + int32_t group0_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 32)); + int32_t group1_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 32)); + int32_t group2_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 32)); + int32_t group3_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 32)); + + vuint32m1_t v_scales_raw = __riscv_vsrl_vx_u32m1(v_aux, 28, 4); + vuint32m1_t v_scales = __riscv_vadd_vx_u32m1( + __riscv_vsll_vx_u32m1(v_scales_raw, 1, 4), + 1, 4); + int32_t scale0 = (int32_t)__riscv_vmv_x_s_u32m1_u32(v_scales); + int32_t scale1 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 1, 4)); + int32_t scale2 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 2, 4)); + int32_t scale3 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 3, 4)); + + block_sum += (float)(group0_sum * scale0 + group1_sum * scale1 + + group2_sum * scale2 + group3_sum * scale3); + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + vuint32m1_t v_shifts; + { + vuint32m1_t v_id = __riscv_vid_v_u32m1(32); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_id, 3, 32); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 32); + } + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_id_16 = __riscv_vid_v_u16mf2(32); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_id_16, 2, 32); + } + + float sumf = 0.0f; + uint32_t aux32[8]; // Buffer for block metadata + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + vuint8mf2_t v_q3_idx_raw = __riscv_vle8_v_u8mf2(q3_indices, 64); + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_raw, 4, 64); + + vuint32m2_t v_q3_grid_vals = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 64); + + vint8m2_t v_q3_mags = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_grid_vals)); + + memcpy(aux32, metadata, 8 * sizeof(uint32_t)); + vuint32m1_t v_aux_8 = __riscv_vle32_v_u32m1(aux32, 8); + + vuint32m1_t v_aux_32 = __riscv_vrgatherei16_vv_u32m1(v_aux_8, v_gather_idx, 32); + + vuint32m1_t v_sign_idx_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_32, v_shifts, 32), 127, 32); + + vuint16mf2_t v_sign_offsets = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_sign_idx_raw, 32), 3, 32); + + vuint64m2_t v_signs_u64 = __riscv_vluxei16_v_u64m2(signs64, v_sign_offsets, 32); + + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_signs_u64)); + + vint8m2_t v_q3_final = __riscv_vmul_vv_i8m2(v_q3_mags, v_signs, 256); + + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_final, 256); + float block_sum = 0.0f; + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m4_t v_accum = v_dot; + + for (int j = 0; j < 8; ++j) { + float scale = (float)(2 * (aux32[j] >> 28) + 1); + + vint32m1_t v_partial_sum = __riscv_vwredsum_vs_i16m4_i32m1(v_accum, v_zero, 32); + + int32_t partial_sum_i = __riscv_vmv_x_s_i32m1_i32(v_partial_sum); + block_sum += partial_sum_i * scale; + v_accum = __riscv_vslidedown_vx_i16m4(v_accum, 32, 32); + + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} #endif void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3723,8 +5463,11 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const case 256: ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + case 512: + ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 1024 and above + ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3732,7 +5475,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3847,7 +5590,7 @@ static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3861,7 +5604,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4007,10 +5750,205 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 128); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 128); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 32)); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 256); + q8 += 256; + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16mf2_t i_vec = __riscv_vle16_v_u16mf2(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m1_t iq4_packed = __riscv_vle8_v_u8m1(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m1_t iq4bits_lo = __riscv_vand_vx_u8m1(iq4_packed, 0xf, 128); + const vuint8m1_t iq4bits_hi = __riscv_vsrl_vx_u8m1(iq4_packed, 4, 128); + const vuint8m2_t iq4bits = __riscv_vcreate_v_u8m1_u8m2(iq4bits_lo, iq4bits_hi); + const vuint8m2_t iq4bits_reorder = __riscv_vreinterpret_v_u64m2_u8m2(__riscv_vrgatherei16_vv_u64m2(__riscv_vreinterpret_v_u8m2_u64m2(iq4bits), i_vec, 32)); + const vint8m2_t iq4b = __riscv_vrgather_vv_i8m2(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + const vint16m4_t prod = __riscv_vwmul_vv_i16m4(iq4b, q8b, 256); + q8 += 256; + + // Mask for processing 32 elements per prod register. + const vuint16m1_t p_index = __riscv_vid_v_u16m1(64); + const vbool16_t p_mask = __riscv_vmsgtu_vx_u16m1_b16(p_index, 31, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 64)); + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -4018,6 +5956,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq4_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq4_xs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -4027,7 +5971,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4230,10 +6174,112 @@ static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m1_t suml1; + { + const int vl = 32; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 0, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 32, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 64, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 96, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 128, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); + } + + // Second loop. + vint16mf2_t suml2; + { + const int vl = 16; + vuint8mf4_t tq = __riscv_vle8_v_u8mf4(x[i].qs + 32, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(tq, 3 * 1, vl), 8, vl); + vuint16mf2_t tq1 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 3, vl), 3, vl), 8, vl); + vuint16mf2_t tq2 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 9, vl), 3, vl), 8, vl); + vuint16mf2_t tq3 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 27, vl), 3, vl), 8, vl); + vuint16mf2_t tq4 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 81, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 160, vl), vl); + vint16mf2_t q81 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 176, vl), vl); + vint16mf2_t q82 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 192, vl), vl); + vint16mf2_t q83 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 208, vl), vl); + vint16mf2_t q84 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 224, vl), vl); + + vint16mf2_t sum0 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + vint16mf2_t sum1 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq1, 1, vl)), q81, vl); + vint16mf2_t sum2 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq2, 1, vl)), q82, vl); + vint16mf2_t sum3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq3, 1, vl)), q83, vl); + vint16mf2_t sum4 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq4, 1, vl)), q84, vl); + + vint16mf2_t sumi0 = __riscv_vadd_vv_i16mf2(sum0, sum1, vl); + vint16mf2_t sumi1 = __riscv_vadd_vv_i16mf2(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16mf2(sum4, __riscv_vadd_vv_i16mf2(sumi0, sumi1, vl), vl); + } + + // Third loop. + vint16mf2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf4_t tq = __riscv_vlmul_trunc_v_u8mf2_u8mf4(__riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4))); + + vuint8mf4_t p = __riscv_vle8_v_u8mf4(pow, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vv_u8mf4(tq, p, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 240, vl), vl); + + suml3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + } + + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(suml1, __riscv_vmv_v_x_i32m1(0, 1), 32); + sum = __riscv_vwredsum_vs_i16mf2_i32m1(__riscv_vadd_vv_i16mf2(suml2, suml3, 16), sum, 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} #endif void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -4241,8 +6287,8 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_tq1_0_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -4250,7 +6296,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -4406,24 +6452,21 @@ static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - case 256: + default: // 256 and above ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; } #else ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4538,7 +6581,7 @@ static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); From e9dbd0c18a1904b84c2b75b8bff81ff6ecb6c886 Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Wed, 3 Jun 2026 22:05:04 -0700 Subject: [PATCH 242/289] ggml-webgpu: FlashAttention refactor + standardize quantization support (llama/23834) * Start work on flash_attn refactor * Refactor * Split k/v quantization * Refactor and abstract quantization logic for flash_attn and mul_mat * Add quantization support to tile path * formatting * Move to functions, add a check --- ggml/src/ggml-webgpu/CMakeLists.txt | 7 +- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 659 +++++++++--------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 416 ++++++----- ggml/src/ggml-webgpu/pre_wgsl.hpp | 44 +- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 271 ++----- .../flash_attn_quant_staging.tmpl | 124 ++++ .../wgsl-shaders/flash_attn_tile.wgsl | 126 ++-- .../wgsl-shaders/flash_attn_vec_split.wgsl | 247 ++----- .../wgsl-shaders/mul_mat_decls.tmpl | 20 +- .../wgsl-shaders/quant_inner_loops.tmpl | 21 + 10 files changed, 985 insertions(+), 950 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 3ccce58aa39..1503a1ef8ba 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR}) message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}") -# Find all WGSL files -file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl") +# Find all WGSL sources +file(GLOB WGSL_SHADER_FILES + "${SHADER_DIR}/*.wgsl" + "${SHADER_DIR}/*.tmpl" +) # Generate the header using a Python script add_custom_command( diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index f4c5eca0df5..a5e7de785b4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -18,6 +18,9 @@ #define GGML_WEBGPU_F32_SIZE_BYTES 4 #define GGML_WEBGPU_I32_SIZE_BYTES 4 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u +#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. #define GGML_WEBGPU_KV_SEQ_PAD 256u @@ -546,16 +549,10 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ -enum ggml_webgpu_flash_attn_path : uint32_t { - GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, - GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, - GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, -}; - -struct ggml_webgpu_flash_attn_pipeline_key { +struct ggml_webgpu_flash_attn_common_pipeline_key { ggml_type q_type; - ggml_type kv_type; + ggml_type k_type; + ggml_type v_type; ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; @@ -564,93 +561,224 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - uint32_t path; + + bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const { + return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type && + dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; + } +}; + +inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, + const ggml_webgpu_flash_attn_common_pipeline_key & key) { + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.k_type); + ggml_webgpu_hash_combine(seed, key.v_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); +} + +struct ggml_webgpu_flash_attn_vec_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + + bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; } +}; + +struct ggml_webgpu_flash_attn_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + bool use_sg_matrix; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { - return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && - kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && path == other.path; + return common == other.common && use_sg_matrix == other.use_sg_matrix; } }; struct ggml_webgpu_flash_attn_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_type); - ggml_webgpu_hash_combine(seed, key.kv_type); - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.head_dim_qk); - ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.kv_direct); - ggml_webgpu_hash_combine(seed, key.kv_overlap); - ggml_webgpu_hash_combine(seed, key.has_mask); - ggml_webgpu_hash_combine(seed, key.has_sinks); - ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.path); + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + ggml_webgpu_hash_combine(seed, key.use_sg_matrix); return seed; } }; +struct ggml_webgpu_flash_attn_vec_decisions { + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; - bool kv_direct = false; - bool kv_overlap = false; + bool use_sg_matrix = false; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; }; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || - key.head_dim_qk != key.head_dim_v) { - return 1u; +inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs; +} + +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { + const uint32_t offset_elems = + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type)); + return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; +} + +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) && + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); +} + +inline bool ggml_webgpu_flash_attn_kv_direct( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) { + return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); +} + +inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key( + const ggml_webgpu_shader_lib_context & context, + uint32_t kv_direct_align) { + ggml_webgpu_flash_attn_common_pipeline_key key = {}; + key.q_type = context.src0->type; + key.k_type = context.src1->type; + key.v_type = context.src2->type; + key.dst_type = context.dst->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; +} + +inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines( + const ggml_webgpu_flash_attn_common_pipeline_key & key, + std::string & variant, + uint32_t q_tile, + uint32_t kv_tile, + uint32_t wg_size) { + std::vector<std::string> defines; + + switch (key.k_type) { + case GGML_TYPE_F32: + defines.push_back("K_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("K_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("K_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("K_Q8_0"); + break; + default: + GGML_ABORT("Unsupported K type for flash attention shader"); + } + variant += std::string("_k") + ggml_type_name(key.k_type); + + switch (key.v_type) { + case GGML_TYPE_F32: + defines.push_back("V_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("V_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("V_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("V_Q8_0"); + break; + default: + GGML_ABORT("Unsupported V type for flash attention shader"); + } + variant += std::string("_v") + ggml_type_name(key.v_type); + + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; + default: + GGML_ABORT("Unsupported Q type for flash attention shader"); } + variant += std::string("_q") + ggml_type_name(key.q_type); - switch (key.head_dim_qk) { - case 64: - case 192: - case 576: - return 2u; - case 96: - return 4u; + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; default: - return 1u; + GGML_ABORT("Unsupported dst type for flash attention shader"); } -} + variant += std::string("_dst") + ggml_type_name(key.dst_type); -inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_decisions & decisions) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - bool kv_direct = false; - if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { - kv_direct_align = context.sg_mat_k; - } - kv_direct = (context.src1->type == GGML_TYPE_F16) && - (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && - (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - } - - ggml_webgpu_flash_attn_pipeline_key key = {}; - key.q_type = context.src0->type; - key.kv_type = context.src1->type; - key.dst_type = context.dst->type; - key.head_dim_qk = (uint32_t) context.src0->ne[0]; - key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = kv_direct; - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = has_mask; - key.has_sinks = has_sinks; - key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.path = decisions.path; - return key; + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) { + defines.push_back("U32_DEQUANT_HELPERS"); + } + + return defines; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { @@ -688,29 +816,18 @@ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { } }; -// This is exposed because it's necessary in supports_op +// Note: this will slightly overestimate memory usage for vec path +// since row_max and exp_sum shmem are not needed. inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, uint32_t head_dim_qk, uint32_t head_dim_v, bool has_mask, - bool kv_direct, - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + bool kv_direct) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t f16_elems = 0; size_t f32_elems = 0; - if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - f32_elems += head_dim_qk; // q_shmem - if (!kv_direct) { - f32_elems += kv_tile * max_head_dim; // kv_shmem - } - f32_elems += head_dim_v; // o_shmem - if (has_mask) { - f32_elems += kv_tile; // mask_shmem - } - f32_elems += kv_tile; // inter_shmem - return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; - } + f32_elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { f32_elems += kv_tile * max_head_dim; // kv_shmem @@ -725,25 +842,20 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_pipeline_key & key) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_granularity = std::max(1u, context.sg_mat_n); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; - kv_granularity = 1u; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - q_tile = 1u; - kv_granularity = 8u; - } - const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, + uint32_t q_tile, + uint32_t kv_granularity, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const size_t base_q_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct); if (limit_bytes <= base_q_bytes) { return 0; } - const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); + const size_t one_kv_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct); const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; if (bytes_per_kv == 0) { return 0; @@ -752,105 +864,32 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } -inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( - const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - ggml_webgpu_flash_attn_decisions decisions = {}; - const size_t alignment = std::max<size_t>(1u, storage_offset_alignment); - const auto * K = context.src1; - const auto * V = context.src2; - GGML_ASSERT(K != nullptr); - GGML_ASSERT(V != nullptr); - - const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { - constexpr uintptr_t ptr_base_addr = 0x1000u; - const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; - return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs; - }; - - const uint32_t k_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && - (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = - K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = - context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0; - // Compile with enough invocations to cover the largest reported subgroup. - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); - const bool tile_can_dispatch_all_q_rows = - context.max_subgroup_size > 0 && - context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && - context.src0->ne[0] % context.sg_mat_k == 0 && - context.src2->ne[0] % context.sg_mat_n == 0; - const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && - V->type == GGML_TYPE_F16 && f16_vec4_aligned && - (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows && !use_vec; - - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { - return decisions; - } - - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - decisions.kv_direct = key.kv_direct; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - // invalidate if even the smallest kv_tile doesn't fit in shared memory - if (max_kv_tile == 0) { - decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - return decisions; - } - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - decisions.q_tile = 1u; - decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); - decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = context.max_subgroup_size; - if (decisions.kv_direct) { - decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= 8u; - } +inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_kv_tile = + ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); + if (kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= 1u; } - return decisions; } - decisions.q_tile = - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; - decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(64u, max_kv_tile) : - std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(std::max(1u, context.max_wg_size), - std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, - GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) : - std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - if (decisions.kv_tile == 0) { - return decisions; - } + return kv_tile; +} - if (decisions.kv_direct) { - GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n; - } - } - return decisions; +inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, + uint32_t sg_mat_k, + uint32_t sg_mat_n, + const ggml_tensor * Q, + const ggml_tensor * V) { + return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0; } /** Matrix Multiplication **/ @@ -1123,6 +1162,10 @@ class ggml_webgpu_shader_lib { concat_pipelines; // type std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash> repeat_pipelines; // type + std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key, + webgpu_pipeline, + ggml_webgpu_flash_attn_vec_pipeline_key_hash> + flash_attn_vec_pipelines; std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash> flash_attn_pipelines; std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key, @@ -1835,10 +1878,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); @@ -1971,11 +2014,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_subgroup_matrix = context.supports_subgroup_matrix; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -2148,10 +2191,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2271,10 +2314,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_vec_pipelines.find(key); if (it != mul_mat_id_vec_pipelines.end()) { @@ -2664,119 +2707,62 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - const ggml_webgpu_flash_attn_decisions decisions = - ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); - GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); - ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - auto it = flash_attn_pipelines.find(key); - if (it != flash_attn_pipelines.end()) { - return it->second; - } - std::vector<std::string> defines; - std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" : - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" : - "flash_attn"; - - switch (key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(key.kv_type); - - switch (key.q_type) { - case GGML_TYPE_F32: - defines.push_back("Q_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("Q_F16"); - break; - default: - GGML_ABORT("Unsupported Q type for flash attention shader"); - } - variant += std::string("_q") + ggml_type_name(key.q_type); - - switch (key.dst_type) { - case GGML_TYPE_F32: - defines.push_back("DST_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("DST_F16"); - break; - default: - GGML_ABORT("Unsupported dst type for flash attention shader"); - } - variant += std::string("_dst") + ggml_type_name(key.dst_type); - - if (key.has_mask) { - defines.push_back("MASK"); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("BLK"); - variant += "_mask_blk"; - } else { - variant += "_mask"; + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); + ggml_webgpu_flash_attn_decisions decisions = {}; + decisions.use_sg_matrix = can_use_subgroup_matrix; + decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.common = + ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u); + key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct; + key.use_sg_matrix = decisions.use_sg_matrix; + + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u, + key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + decisions.kv_tile = decisions.use_sg_matrix ? + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) : + std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile); + decisions.wg_size = + decisions.use_sg_matrix ? + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) : + std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)); + + if (key.common.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size; } } - if (key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - if (key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - if (key.kv_overlap) { - defines.push_back("KV_OVERLAP"); - variant += "_kv_overlap"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } - const char * shader_src = wgsl_flash_attn; - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("KV_GRANULARITY=8"); - defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); - shader_src = wgsl_flash_attn_vec_split; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile"; + std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile, + decisions.kv_tile, decisions.wg_size); + const char * shader_src = nullptr; + if (!key.use_sg_matrix) { shader_src = wgsl_flash_attn_tile; defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); - defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + std::to_string(context.max_subgroup_size); } else { + shader_src = wgsl_flash_attn; defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - - auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); - pipeline_decisions->kv_overlap = key.kv_overlap; - defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); - + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); pipeline.context = pipeline_decisions; @@ -2784,6 +2770,55 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_pipeline_key key = {}; + key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; + } + + ggml_webgpu_flash_attn_vec_decisions decisions = {}; + decisions.kv_tile = + ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk, + key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + decisions.wg_size = context.max_subgroup_size; + + std::string variant = "flash_attn_vec"; + std::vector<std::string> defines = + ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size); + if (key.common.has_mask) { + defines.push_back("BLK"); + variant.resize(variant.size() - (sizeof("_mask") - 1)); + variant += "_mask_blk"; + } + uint32_t vec_ne = 1u; + if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 && + key.common.head_dim_qk == key.common.head_dim_v) { + switch (key.common.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions); + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = pipeline_decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { ggml_webgpu_flash_attn_blk_pipeline_key key = {}; key.kv_tile = kv_tile; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d577b5afa3c..c6cfb0bbbad 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1755,13 +1755,50 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +struct ggml_webgpu_flash_attn_op { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + std::vector<uint32_t> params; + std::vector<wgpu::BindGroupEntry> entries; + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + bool has_mask = false; + bool has_sinks = false; + bool kv_overlap = false; +}; + +static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const bool k_float_vec4_aligned = (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment); + const bool v_float_vec4_aligned = (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); + const bool k_vec_type_supported = + K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool v_vec_type_supported = + V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; + const uint32_t k_vec_head_align = (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const uint32_t v_vec_head_align = (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(V->type); + const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; + + return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && + kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned && + v_float_vec4_aligned; +} + +static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = ggml_get_op_params_f32(dst, 0); float max_bias = ggml_get_op_params_f32(dst, 1); float logit_softcap = ggml_get_op_params_f32(dst, 2); @@ -1772,47 +1809,43 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = Q; - shader_lib_ctx.src1 = K; - shader_lib_ctx.src2 = V; - shader_lib_ctx.src3 = mask; - shader_lib_ctx.src4 = sinks; - shader_lib_ctx.dst = dst; - shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( - shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get()); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); - const bool kv_overlap = decisions->kv_overlap; - - uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - size_t kv_bind_offset = 0; - size_t kv_bind_size = 0; - if (kv_overlap) { + ggml_webgpu_flash_attn_op op = {}; + op.shader_lib_ctx.src0 = Q; + op.shader_lib_ctx.src1 = K; + op.shader_lib_ctx.src2 = V; + op.shader_lib_ctx.src3 = mask; + op.shader_lib_ctx.src4 = sinks; + op.shader_lib_ctx.dst = dst; + op.shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + op.shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + op.shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + op.shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + op.shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + op.shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + op.shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + op.shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; + op.shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + + op.has_mask = mask != nullptr; + op.has_sinks = sinks != nullptr; + op.kv_overlap = ggml_webgpu_tensor_overlap(K, V); + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + if (op.kv_overlap) { const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); - kv_bind_offset = merged_range.offset; - kv_bind_size = merged_range.size; + op.kv_bind_offset = merged_range.offset; + op.kv_bind_size = merged_range.size; offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); } - std::vector<uint32_t> params = { + op.params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), offset_k, offset_v, - has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, - has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, + op.has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + op.has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) Q->ne[2], // number of heads (uint32_t) Q->ne[1], // sequence length (Q) @@ -1826,7 +1859,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + op.has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) ggml_webgpu_u32_from_f32(max_bias), @@ -1834,32 +1867,56 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_u32_from_f32(n_head_log2), ggml_webgpu_u32_from_f32(m0), ggml_webgpu_u32_from_f32(m1) - }; - std::vector<wgpu::BindGroupEntry> entries = { + op.entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), }; - if (kv_overlap) { - entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + if (op.kv_overlap) { + op.entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); } - uint32_t binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); + uint32_t binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } - if (has_sinks) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + if (op.has_sinks) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); } - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return op; +} + +static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv_tile, uint32_t seq_len_kv) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) kv_tile; + while ((2u * nwg * kv_span) < (uint64_t) seq_len_kv && nwg < vec_nwg_cap) { + nwg <<= 1; } + return std::min(nwg, vec_nwg_cap); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; + return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst, + ggml_webgpu_flash_attn_op op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1868,13 +1925,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_batch_count = 0; const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; const bool use_vec_reduce = nwg > 1u; GGML_ASSERT(nrows <= UINT32_MAX); @@ -1910,7 +1962,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, webgpu_pipeline blk_pipeline; std::vector<uint32_t> blk_params; std::vector<wgpu::BindGroupEntry> blk_entries; - if (has_mask) { + if (op.has_mask) { blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); blk_nblk1 = (uint32_t) Q->ne[1]; blk_buf = ggml_webgpu_tensor_buf(dst); @@ -1918,7 +1970,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + const ggml_webgpu_shader_lib_context blk_shader_ctx = op.shader_lib_ctx; blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); blk_params = { @@ -1938,8 +1990,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); } - std::vector<uint32_t> split_params = params; - if (has_mask) { + std::vector<uint32_t> split_params = op.params; + if (op.has_mask) { split_params.push_back(0u); // blk_base split_params.push_back(blk_nblk0); // blk_nblk0 split_params.push_back(blk_nblk1); // blk_nblk1 @@ -1952,9 +2004,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), ggml_webgpu_tensor_binding_size(ctx, Q)), }; - if (kv_overlap) { + if (op.kv_overlap) { split_entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), @@ -1963,18 +2015,18 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_tensor_align_offset(ctx, V), ggml_webgpu_tensor_binding_size(ctx, V))); } - uint32_t split_binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { + uint32_t split_binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), ggml_webgpu_tensor_align_offset(ctx, mask), ggml_webgpu_tensor_binding_size(ctx, mask))); } - if (has_sinks) { + if (op.has_sinks) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), ggml_webgpu_tensor_align_offset(ctx, sinks), ggml_webgpu_tensor_binding_size(ctx, sinks))); } - if (has_mask) { + if (op.has_mask) { split_entries.push_back( ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); } @@ -1993,7 +2045,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, reduce_sg_size, (uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + ggml_webgpu_shader_lib_context reduce_shader_ctx = op.shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -2020,7 +2072,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector<webgpu_dispatch_desc> dispatches; - if (has_mask) { + if (op.has_mask) { dispatches.push_back({ blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } }); @@ -2037,6 +2089,20 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + ggml_webgpu_flash_attn_op op = ggml_webgpu_flash_attn_prepare(ctx, Q, K, V, mask, sinks, dst); + if (ggml_webgpu_flash_attn_use_vec_path(ctx->global_ctx, Q, K, V)) { + return ggml_webgpu_flash_attn_vec(ctx, Q, K, V, mask, sinks, dst, std::move(op)); + } + return ggml_webgpu_flash_attn_direct(ctx, op); +} + static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -3553,70 +3619,43 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer break; case GGML_OP_FLASH_ATTN_EXT: { - const ggml_tensor * Q = tensor->src[0]; - const ggml_tensor * K = tensor->src[1]; - const ggml_tensor * V = tensor->src[2]; - const ggml_tensor * mask = tensor->src[3]; - const ggml_tensor * sinks = tensor->src[4]; - if (Q && K && V) { - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = const_cast<ggml_tensor *>(Q); - shader_lib_ctx.src1 = const_cast<ggml_tensor *>(K); - shader_lib_ctx.src2 = const_cast<ggml_tensor *>(V); - shader_lib_ctx.src3 = const_cast<ggml_tensor *>(mask); - shader_lib_ctx.src4 = const_cast<ggml_tensor *>(sinks); - shader_lib_ctx.dst = const_cast<ggml_tensor *>(tensor); - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = - ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const uint32_t kv_tile = decisions.kv_tile; - - const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - - const size_t align = - ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - if (nwg > 1u) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( - (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += tmp_size_bytes + align; - } else { - res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; - } - if (mask != nullptr) { - const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); - const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - const size_t blk_size_bytes = - ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += blk_size_bytes + align; - } - res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { + const bool kv_direct = + ggml_webgpu_flash_attn_kv_direct(Q, K, V, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0], + mask != nullptr, kv_direct); + + const uint32_t vec_nwg_cap = capabilities.min_subgroup_size; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]); + + const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), + WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); } } break; @@ -4139,70 +4178,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; case GGML_OP_FLASH_ATTN_EXT: { + // conservative support checks for whether the more resource-intensive shader paths + // can be used, to avoid cases where flash_attn is assigned to the CPU later on supports_op = src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; + (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16 || + src2->type == GGML_TYPE_Q4_0 || src2->type == GGML_TYPE_Q8_0) && + op->type == GGML_TYPE_F32; if (!supports_op) { break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = src0; - shader_lib_ctx.src1 = src1; - shader_lib_ctx.src2 = src2; - shader_lib_ctx.src3 = op->src[3]; - shader_lib_ctx.src4 = op->src[4]; - shader_lib_ctx.dst = const_cast<ggml_tensor *>(op); - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + if (ggml_webgpu_tensor_overlap(src1, src2) && src1->type != src2->type && + !ggml_is_quantized(src1->type) && !ggml_is_quantized(src2->type)) { supports_op = false; break; } - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; + + // subgroup matrix path requirements + const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); + + // tile path requirements + const bool float_vec4_aligned = + ((src1->type != GGML_TYPE_F16 && src1->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) && + ((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment)); + const uint32_t k_tile_head_align = (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src1->type); + const uint32_t v_tile_head_align = (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src2->type); + const bool tile_kv_head_dims_aligned = + src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0; + const bool tile_can_dispatch_all_q_rows = + capabilities.limits.maxComputeInvocationsPerWorkgroup >= + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; + const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && float_vec4_aligned && + tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows; + + if (!use_subgroup_matrix && !use_tile) { supports_op = false; break; } - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } + const uint32_t q_tile = + use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); + supports_op = max_kv_tile > 0; break; } case GGML_OP_RMS_NORM: diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp index 4d4359463ca..fb41a961d74 100644 --- a/ggml/src/ggml-webgpu/pre_wgsl.hpp +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -37,15 +37,33 @@ static std::string trim(const std::string & s) { } static std::string trim_value(std::istream & is) { - std::string str; - std::getline(is, str); - return trim(str); + std::ostringstream ss; + ss << is.rdbuf(); + return trim(ss.str()); } static bool isIdentChar(char c) { return std::isalnum(static_cast<unsigned char>(c)) || c == '_'; } +static bool endsWithContinuation(const std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + return i > 0 && line[i - 1] == '\\'; +} + +static void stripContinuation(std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + if (i > 0 && line[i - 1] == '\\') { + line.erase(i - 1); + } +} + static std::string expandMacrosRecursiveInternal(const std::string & line, const std::unordered_map<std::string, std::string> & macros, std::unordered_set<std::string> & visiting); @@ -595,19 +613,31 @@ class Preprocessor { std::string line; while (std::getline(in, line)) { - std::string t = trim(line); + std::string logical = line; + std::string t = trim(logical); + if (!t.empty() && t[0] == '#') { + while (endsWithContinuation(logical)) { + stripContinuation(logical); + if (!std::getline(in, line)) { + break; + } + logical += "\n"; + logical += line; + } + t = trim(logical); + } if (!t.empty() && t[0] == '#') { bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); if (mode == DirectiveMode::IncludesOnly && !handled) { - out << line << "\n"; + out << logical << "\n"; } } else { if (mode == DirectiveMode::IncludesOnly) { - out << line << "\n"; + out << logical << "\n"; } else if (condActive(cond)) { // Expand macros in the line before outputting - std::string expanded = expandMacrosRecursive(line, macros); + std::string expanded = expandMacrosRecursive(logical, macros); out << expanded << "\n"; } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 6d5d69fb8de..9767ca3d754 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,12 +4,23 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; -#ifdef KV_F32 -#define KV_TYPE f32 -#elif defined(KV_Q4_0) || defined(KV_Q8_0) -#define KV_TYPE u32 +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else -#define KV_TYPE f16 +#define V_TYPE f16 #endif // Default values @@ -30,76 +41,6 @@ enable chromium_experimental_subgroup_matrix; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. #define KV_BLOCKS (KV_TILE / SG_MAT_N) -// Quantization constants/helpers -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread -#if defined(KV_Q4_0) -#define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights -#define F16_PER_BLOCK 9 -#define BLOCK_SIZE_BYTES 18u -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights -#define F16_PER_BLOCK 17 -#define BLOCK_SIZE_BYTES 34u -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -// Ok not to put these in a define block, compiler will remove if unused -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - -#if defined(KV_Q4_0) || defined(KV_Q8_0) -fn load_k_u16_at(byte_offset: u32) -> u32 { - let word = K[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_k_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = K[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = K[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} - -fn load_v_u16_at(byte_offset: u32) -> u32 { - let word = V[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_v_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = V[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = V[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} - -fn f16_from_u16(bits: u32) -> f16 { - let packed = unpack2x16float(bits); - return f16(packed[0]); -} -#endif - struct Params { offset_q: u32, offset_k: u32, @@ -139,11 +80,11 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<f32>; #ifdef KV_OVERLAP -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #define V K #else -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; -@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>; +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; #endif #if defined(MASK) && defined(SINKS) @@ -238,10 +179,47 @@ fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32 return (*buf)[scalar_index >> 2u]; } -fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> { +fn load_kx4(buf: ptr<storage, array<vec4<K_TYPE>>, read_write>, scalar_index: u32) -> vec4<K_TYPE> { return (*buf)[scalar_index >> 2u]; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + kv_shmem[elem_idx] = f16(select( + 0.0, + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); + } +} +#endif +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>, @@ -311,77 +289,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); // clear inter_shmem to ensure zero-initialized accumulators for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { inter_shmem[elem_idx] = 0.0; } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - kv_shmem[elem_idx] = f16(select( - 0.0, - K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -520,71 +436,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - kv_shmem[elem_idx] = f16(select( - 0.0, - V[global_v_row_offset + v_col], - global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl new file mode 100644 index 00000000000..8f41eb7bfdb --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl @@ -0,0 +1,124 @@ +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) + +#if defined(K_Q4_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 18u +#define K_BYTES_PER_THREAD 8u +#define K_BYTES_PER_INNER_LOOP 4u +#elif defined(K_Q8_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 34u +#define K_BYTES_PER_THREAD 16u +#define K_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(V_Q4_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 18u +#define V_BYTES_PER_THREAD 8u +#define V_BYTES_PER_INNER_LOOP 4u +#elif defined(V_Q8_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 34u +#define V_BYTES_PER_THREAD 16u +#define V_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; + let q_packed = load_k_u32_at(q_byte_offset); +#if defined(K_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#elif defined(K_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; + let q_packed = load_v_u32_at(q_byte_offset); +#if defined(V_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#elif defined(V_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 4133f0ab564..e68934113fc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -1,16 +1,29 @@ enable f16; enable subgroups; +#define BYTE_HELPERS +#include "common_decls.tmpl" + #ifdef Q_F16 #define Q_TYPE f16 #else #define Q_TYPE f32 #endif -#ifdef KV_F32 -#define KV_TYPE f32 +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else -#define KV_TYPE f16 +#define V_TYPE f16 #endif #ifdef DST_F16 @@ -21,7 +34,6 @@ enable subgroups; #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 -#define KV_STAGE_STRIDE 64 #define Q_TILE 4 #define KV_TILE 64 #define WG_SIZE 128 @@ -64,11 +76,23 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; #ifdef KV_OVERLAP -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif #define V K #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; -@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; +#else +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; +#endif #endif #if defined(MASK) && defined(SINKS) @@ -121,10 +145,50 @@ const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>; -var<workgroup> kv_shmem: array<KV_TYPE, KV_TILE * KV_STAGE_STRIDE>; -var<workgroup> p_shmem: array<KV_TYPE, Q_TILE * KV_TILE>; +var<workgroup> kv_shmem: array<f16, kv_shmem_size>; +var<workgroup> p_shmem: array<f16, Q_TILE * KV_TILE>; + +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(k4.x); + kv_shmem[kv_off + 1u] = f16(k4.y); + kv_shmem[kv_off + 2u] = f16(k4.z); + kv_shmem[kv_off + 3u] = f16(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(v4.x); + kv_shmem[kv_off + 1u] = f16(v4.y); + kv_shmem[kv_off + 2u] = f16(v4.z); + kv_shmem[kv_off + 3u] = f16(v4.w); + } +} +#endif @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @@ -206,18 +270,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, local_scores[slot] = FLOAT_MIN; } - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / Q_CHUNKS; - let chunk = vec_idx_local % Q_CHUNKS; - let global_k_row = kv_tile + kv_local; - let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; - let k4 = K[k_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); +#endif workgroupBarrier(); @@ -238,8 +293,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4<KV_TYPE>( + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + let kv = vec4<f16>( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], @@ -271,25 +326,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); + p_shmem[subgroup_p_offset + kv_local] = f16(p); local_sum += p; } } workgroupBarrier(); - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / V_CHUNKS; - let chunk = vec_idx_local % V_CHUNKS; - let global_v_row = kv_tile + kv_local; - let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; - let v4 = V[v_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); +#endif workgroupBarrier(); @@ -306,14 +352,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, var acc = out_regs[reg_idx]; for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { - let p = p_shmem[subgroup_p_offset + kv_local]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4<KV_TYPE>( + let p = f32(p_shmem[subgroup_p_offset + kv_local]); + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + let v4 = vec4<f16>( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += f32(p) * vec4<f32>(v4); + acc += p * vec4<f32>(v4); } out_regs[reg_idx] = acc; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 30ebbebe772..30ed97cca0c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -2,10 +2,23 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; -#ifdef KV_F32 -#define KV_TYPE f32 +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 #else -#define KV_TYPE f16 +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 +#else +#define V_TYPE f16 #endif #ifdef Q_F16 @@ -32,28 +45,6 @@ enable subgroups; #define KV_BLOCKS (KV_TILE / KV_GRANULARITY) -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -#if defined(KV_Q4_0) -#define NQ 16 -#define F16_PER_BLOCK 9 -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -#define F16_PER_BLOCK 17 -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - struct Params { offset_q: u32, offset_k: u32, @@ -103,22 +94,22 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; #ifdef KV_OVERLAP -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; #endif #define V K #else -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; #endif -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>; +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; #else -@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>; +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; #endif #endif #if defined(MASK) && defined(SINKS) @@ -244,6 +235,49 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) return v; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f32 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4<K_TYPE>(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4<V_TYPE>(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); + } +} +#endif +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>, @@ -308,6 +342,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); #ifdef BLK let q_blk = q_row_start; let kv_blk = kv_tile / KV_TILE; @@ -324,76 +359,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; - let vec_idx = (global_k_row_offset + k_col) >> 2u; - let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(k4.x); - kv_shmem[elem_idx + 1u] = f32(k4.y); - kv_shmem[elem_idx + 2u] = f32(k4.z); - kv_shmem[elem_idx + 3u] = f32(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -510,76 +477,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; - let vec_idx = (global_v_row_offset + v_col) >> 2u; - let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(v4.x); - kv_shmem[elem_idx + 1u] = f32(v4.y); - kv_shmem[elem_idx + 2u] = f32(v4.z); - kv_shmem[elem_idx + 3u] = f32(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index eb2a8368f43..72991504dd0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -25,6 +25,10 @@ fn store_shmem(val: f16, idx: u32) { } #endif // SCALAR +#define QUANT_SHMEM shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" + #ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { @@ -124,14 +128,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; - } + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } } } @@ -314,12 +311,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte_i32(q_packed, k); - - let q_val = f16(q_byte) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; - } + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl new file mode 100644 index 00000000000..d1da4608434 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl @@ -0,0 +1,21 @@ +#ifdef U32_DEQUANT_HELPERS +fn dequant_q4_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (QUANT_OUT_TYPE((q_byte >> 4) & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + let q_lo = (QUANT_OUT_TYPE(q_byte & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + QUANT_SHMEM[dst_idx + k] = q_lo; + QUANT_SHMEM[dst_idx + k + 16u] = q_hi; + } +} + +fn dequant_q8_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = QUANT_OUT_TYPE(q_byte) * scale; + QUANT_SHMEM[dst_idx + k] = q_val; + } +} +#endif From 9d6e561f692b9b6353a33fa63e8b8a6998a41cb1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Thu, 4 Jun 2026 08:05:32 +0300 Subject: [PATCH 243/289] metal : reduce rset heartbeat from 500ms -> 5ms (llama/24074) --- ggml/src/ggml-metal/ggml-metal-device.m | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 196af102643..05d7f43051b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -547,6 +547,8 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { // number of seconds since the last graph computation // keep the residency sets wired for that amount of time to avoid being collected by the OS int keep_alive_s; + int loops_per_s; + int time_per_loop_ms; // background heartbeat thread to keep the residency sets alive atomic_bool d_stop; @@ -573,10 +575,13 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { res->keep_alive_s = 3*60; } + res->time_per_loop_ms = 5; + res->loops_per_s = 1000/res->time_per_loop_ms; + GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s); atomic_store_explicit(&res->d_stop, false, memory_order_relaxed); - atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&res->d_loop, res->loops_per_s*res->keep_alive_s, memory_order_relaxed); res->d_group = dispatch_group_create(); @@ -599,8 +604,7 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { [res->lock unlock]; } - // half a second - usleep(500 * 1000); + usleep(res->time_per_loop_ms * 1000); } } #endif @@ -979,7 +983,7 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { return; } - atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&dev->rsets->d_loop, dev->rsets->loops_per_s*dev->rsets->keep_alive_s, memory_order_relaxed); } struct ggml_metal_event { From 991b5a8b4ab652b0bc282f500a58de565e7aa0bc Mon Sep 17 00:00:00 2001 From: Kartik Sirohi <99896785+sirohikartik@users.noreply.github.com> Date: Thu, 4 Jun 2026 18:42:38 +0530 Subject: [PATCH 244/289] ggml: vectorize ggml_vec_dot_q4_1_q8_1 with WASM SIMD128 (llama/22209) * ggml: vectorize ggml_vec_dot_q4_1_q8_1 with WASM SIMD128 Optimize the inner loop of ggml_vec_dot_q4_1_q8_1_generic using WASM SIMD128 intrinsics, gated behind #ifdef __wasm_simd128__ so non-wasm builds are completely unaffected. Approach: - single wasm_v128_load covers all 32 packed 4-bit weights - nibbles unpacked via AND/SHR into two u8x16 registers - widened to i16 before multiply (WASM SIMD has no i8*i8 instruction) - 4x wasm_i32x4_dot_i16x8 calls accumulate all 32 element pairs - horizontal reduce via 4x wasm_i32x4_extract_lane Benchmark (node v25, emcc -O3 -msimd128, 64 blocks x QK8_1=32, 200k iterations): | impl | ns/call | speedup | |--------|---------|---------| | scalar | 880.7 | 1.00x | | simd | 257.8 | 3.42x | Correctness verified against scalar reference across 10 random seeds with exact output match. * ggml: move q4_1_q8_1 WASM SIMD implementation to wasm backend Relocate the SIMD128 implementation of ggml_vec_dot_q4_1_q8_1 to ggml/src/ggml-cpu/arch/wasm/quants.c to follow architecture-specific layout. Restore the generic implementation in ggml/src/ggml-cpu/quants.c. Move for loop in the else block. * ggml: use generic q4_1_q8_1 fallback in wasm backend --- ggml/src/ggml-cpu/arch/wasm/quants.c | 72 ++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 648c6fcaba7..0a7119b4e1f 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -355,6 +355,78 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; } +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + float sumf = 0; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + float summs = 0.0f; + + for (int ib = 0; ib < nb; ++ib) { + const block_q4_1 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; + + summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s); + + const v128_t raw = wasm_v128_load(x0->qs); + const v128_t v0s = wasm_v128_and(raw, wasm_i8x16_splat(0x0F)); + const v128_t v1s = wasm_u8x16_shr(raw, 4); + + const v128_t ys_lo = wasm_v128_load(y0->qs); + const v128_t ys_hi = wasm_v128_load(y0->qs + 16); + + const v128_t v0s_l = wasm_u16x8_extend_low_u8x16(v0s); + const v128_t v0s_h = wasm_u16x8_extend_high_u8x16(v0s); + const v128_t ylo_l = wasm_i16x8_extend_low_i8x16(ys_lo); + const v128_t ylo_h = wasm_i16x8_extend_high_i8x16(ys_lo); + const v128_t v1s_l = wasm_u16x8_extend_low_u8x16(v1s); + const v128_t v1s_h = wasm_u16x8_extend_high_u8x16(v1s); + const v128_t yhi_l = wasm_i16x8_extend_low_i8x16(ys_hi); + const v128_t yhi_h = wasm_i16x8_extend_high_i8x16(ys_hi); + + const v128_t acc = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v0s_l, ylo_l), + wasm_i32x4_dot_i16x8(v0s_h, ylo_h)), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v1s_l, yhi_l), + wasm_i32x4_dot_i16x8(v1s_h, yhi_h))); + + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul( + wasm_f32x4_convert_i32x4(acc), + wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + + *s = sumf; + +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + UNUSED(sumf); + + ggml_vec_dot_q4_1_q8_1_generic( + n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; From 4ecede8c8bb6b4d899a13e27b8b672ad1bc67311 Mon Sep 17 00:00:00 2001 From: Mason Milburn <masonmilby@gmail.com> Date: Fri, 5 Jun 2026 01:10:31 -0400 Subject: [PATCH 245/289] sycl : port multi-column MMVQ from CUDA backend (llama/21845) mmvq: Port the ncols_dst optimization from ggml-cuda/mmvq.cu to SYCL. Read weights once per dispatch instead of once per column. Covers all standard quant types + reorder paths for Q4_0, Q8_0, Q3_K, Q4_K, Q5_K, Q6_K. IQ types (except IQ4_XS) excluded due to incompatible vec_dot signatures. ggml-sycl: The weight reorder was only bootstrapped on single-token mat-vec (ne[1] == 1). Speculative / MTP verify issues only multi-column mat-vec, so it never triggered the reorder and ran on the slower non-reorder kernel. Bootstrap it on small multi-column batches (ne[1] <= 8) too. --- ggml/src/ggml-sycl/ggml-sycl.cpp | 4 +- ggml/src/ggml-sycl/mmvq.cpp | 1118 +++++++++++++++++++++++++++++- 2 files changed, 1095 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 96138f57ebe..3f246e8672d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3971,7 +3971,9 @@ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_ten return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. - dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; + // ne[1] <= 8 so multi-column decode (spec / MTP verify) also bootstraps the reorder; + // all reorderable types have a _switch_ncols kernel. + dst->src[1]->ne[1] <= 8 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; } static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index abd1e49a70e..cf2b59576aa 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -56,6 +56,65 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r } } +template <typename reorder_vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_reorder_ncols(const void * __restrict__ vx, const void * __restrict__ vy, + float * __restrict__ dst, const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + float partial_sum[ncols_dst] = {0.0f}; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; + + const auto bx_offset = block_type::get_block_offset(ibx, nblocks); + const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx); + const int iby = i * block_type::block_to_q8_1_ratio(); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const char * vy_j = (const char *)vy + j * stride_col_y_bytes; + const int8_t * q8_1_quant_ptr = (const int8_t *)vy_j + iby * QK8_1; + const sycl::half2* q8_1_ds_ptr = (const sycl::half2 *)(vy_j + ncols + iby * sizeof(sycl::half2)); + + partial_sum[j] += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs); + } + } + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + float sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum[j], std::plus<>()); + + if (sg.leader()) { + dst[j * stride_col_dst + row] = sum; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl> static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) { @@ -100,6 +159,70 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ } } +template <int qk, int qi, typename block_q_t, int vdr, + vec_dot_q_sycl_t vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_ncols( + const void * __restrict__ vx, + const void * __restrict__ vy, + float * __restrict__ dst, + const int ncols, + const int nrows, + const int stride_col_y, + const int stride_col_dst, + const sycl::nd_item<3> & item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + // partial sums: one per output column + float tmp[ncols_dst] = {0.0f}; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); + i < blocks_per_row; + i += blocks_per_warp) { + + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + // read weight block once, dot against all columns + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + tmp[j] += vec_dot_q_sycl(&x[ibx], &y[j * stride_col_y + iby], iqs); + } + } + } + + // reduce within subgroup +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp[j] += dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), tmp[j], mask); + } + } + + if (item_ct1.get_local_id(2) == 0) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + dst[j * stride_col_dst + row] = tmp[j]; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr> static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, const void *__restrict__ vy, @@ -553,6 +676,45 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); @@ -571,6 +733,45 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * } } +template <int ncols_dst> +static void mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_0, block_q4_0, + VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -595,6 +796,45 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_1, block_q4_1, + VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_1 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_MXFP4 == 0); @@ -613,6 +853,45 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +template <int ncols_dst> +static void mul_mat_vec_mxfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_MXFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_MXFP4, QI_MXFP4, block_mxfp4, + VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_mxfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_mxfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_mxfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_mxfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_mxfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_mxfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_mxfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_mxfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for MXFP4 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_NVFP4 == 0); @@ -631,6 +910,45 @@ static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float } } +template <int ncols_dst> +static void mul_mat_vec_nvfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_NVFP4, QI_NVFP4, block_nvfp4, + VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_nvfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_nvfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_nvfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_nvfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_nvfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_nvfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_nvfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_nvfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for NVFP4 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -655,6 +973,45 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_0, QI5_0, block_q5_0, + VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -679,6 +1036,45 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_1, QI5_1, block_q5_1, + VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_1 multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK8_0 == 0); @@ -698,6 +1094,45 @@ static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -722,6 +1157,45 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK8_0, QI8_0, block_q8_0, + VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -746,6 +1220,45 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q2_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI2_K, block_q2_K, + VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q2_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q2_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q2_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q2_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q2_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q2_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q2_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q2_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q2_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -790,6 +1303,85 @@ static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q3_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K reorder multi-col MMVQ", ncols_dst); + } +} + +template <int ncols_dst> +static void mul_mat_vec_q3_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI3_K, block_q3_K, + VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q3_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q3_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q3_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q3_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q3_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q3_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q3_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q3_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K multi-col MMVQ", ncols_dst); + } +} + + static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -814,6 +1406,51 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_K, block_q4_K, + VDR_Q4_K_Q8_1_MMVQ, + vec_dot_q4_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -834,6 +1471,44 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K reorder multi-col MMVQ", ncols_dst); + } +} static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -859,6 +1534,51 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI5_K, block_q5_K, + VDR_Q5_K_Q8_1_MMVQ, + vec_dot_q5_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -879,6 +1599,45 @@ static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q5_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K reorder multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -897,6 +1656,46 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, }); }); } + +template <int ncols_dst> +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q6_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -921,6 +1720,51 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q6_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI6_K, block_q6_K, + VDR_Q6_K_Q8_1_MMVQ, + vec_dot_q6_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q6_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q6_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q6_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q6_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q6_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q6_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q6_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q6_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -1117,6 +1961,51 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_iq4_xs_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_XS/4, block_iq4_xs, + 1, + vec_dot_iq4_xs_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_iq4_xs_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for IQ4_XS multi-col MMVQ", ncols_dst); + } +} + void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, @@ -1143,42 +2032,135 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q4_0: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q8_0: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); - reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n"); mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q2_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q3_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, - stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n"); mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1186,9 +2168,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n"); mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1196,9 +2196,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q5_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n"); mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1206,9 +2224,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q6_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n"); mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1238,13 +2274,43 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_MXFP4: - mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_NVFP4: - mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; default: GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); From 4fa1e0687e23bfaf286ba328c2c6d0d592cb3158 Mon Sep 17 00:00:00 2001 From: Oliver Simons <osimons@nvidia.com> Date: Fri, 5 Jun 2026 08:37:34 +0200 Subject: [PATCH 246/289] CUDA: enroll mul_mat_vec_q_moe into pdl (llama/24087) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Enroll mul_mat_vec_q_moe into PDL, boosting MTP performance on BW Data collected on a B4500: Before ``` (llama.cpp) ➜ llama.cpp git:(master) ✗ python mtp-bench.py code_python pred= 192 draft= 150 acc= 116 rate=0.773 tok/s=202.8 code_cpp pred= 192 draft= 147 acc= 117 rate=0.796 tok/s=212.8 explain_concept pred= 192 draft= 161 acc= 110 rate=0.683 tok/s=196.4 summarize pred= 192 draft= 138 acc= 122 rate=0.884 tok/s=226.6 qa_factual pred= 192 draft= 138 acc= 121 rate=0.877 tok/s=225.1 translation pred= 192 draft= 158 acc= 112 rate=0.709 tok/s=201.5 creative_short pred= 192 draft= 160 acc= 110 rate=0.688 tok/s=197.2 stepwise_math pred= 192 draft= 150 acc= 115 rate=0.767 tok/s=209.2 long_code_review pred= 192 draft= 148 acc= 116 rate=0.784 tok/s=208.9 ``` After ``` (llama.cpp) ➜ llama.cpp git:(master) ✗ python mtp-bench.py code_python pred= 192 draft= 150 acc= 116 rate=0.773 tok/s=211.9 code_cpp pred= 192 draft= 147 acc= 117 rate=0.796 tok/s=224.6 explain_concept pred= 192 draft= 161 acc= 110 rate=0.683 tok/s=207.8 summarize pred= 192 draft= 138 acc= 122 rate=0.884 tok/s=240.2 qa_factual pred= 192 draft= 138 acc= 121 rate=0.877 tok/s=238.5 translation pred= 192 draft= 158 acc= 112 rate=0.709 tok/s=213.4 creative_short pred= 192 draft= 160 acc= 110 rate=0.688 tok/s=208.8 stepwise_math pred= 192 draft= 150 acc= 115 rate=0.767 tok/s=221.7 long_code_review pred= 192 draft= 148 acc= 116 rate=0.784 tok/s=220.7 ``` Server launched with: ``` ➜ llama.cpp git:(osimons/enroll_mul_mat_vec_q_moe_into_PDL) ✗ ./build-x64-linux-gcc-reldbg/bin/llama-server \ -m /mnt/share/gguf/unsloth/Qwen3.6-35B-A3B-MTP-GGUF/Qwen3.6-35B-A3B-UD-Q4_K_M.gguf -dio \ --spec-type draft-mtp \ --spec-draft-n-max 2 \ -ngl all \ -fa on \ --host 0.0.0.0 \ --port 8080 -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" ``` * LC to overlap with following kernels --- ggml/src/ggml-cuda/mmvq.cu | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 4b0426590ac..bdfbfd2d387 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -682,12 +682,16 @@ static __global__ void mul_mat_vec_q( template <ggml_type type, int c_rows_per_block> __launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q_moe( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, - float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, + float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint32_t ncols_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qi = ggml_cuda_type_traits<type>::qi; @@ -707,6 +711,7 @@ static __global__ void mul_mat_vec_q_moe( return; } + ggml_cuda_pdl_sync(); const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); @@ -726,6 +731,8 @@ static __global__ void mul_mat_vec_q_moe( } } + ggml_cuda_pdl_lc(); + // Warp-level reduction only - no shared memory needed #pragma unroll for (int i = 0; i < c_rows_per_block; ++i) { @@ -794,8 +801,9 @@ static void mul_mat_vec_q_moe_launch( const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; const dim3 block_nums(nblocks_rows, nchannels_dst); const dim3 block_dims(warp_size, ncols_dst); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); - mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>( + ggml_cuda_kernel_launch(mul_mat_vec_q_moe<type, rows_per_block>, launch_params, vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, stride_row_x, stride_col_y, stride_col_dst, stride_channel_x, stride_channel_y, stride_channel_dst, From facb02c4c3a32f07935c5da60e92b0f650f1bd40 Mon Sep 17 00:00:00 2001 From: Charles Xu <charles.xu@arm.com> Date: Fri, 5 Jun 2026 09:11:47 +0200 Subject: [PATCH 247/289] kleidiai : dynamic chunck-based scheduling for hybrid execution (llama/23819) --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 272 ++++++++++++------------ 1 file changed, 141 insertions(+), 131 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 0ecf7ae02ac..9e54b676b93 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -38,6 +38,7 @@ #include "kleidiai.h" #include "ggml-cpu.h" +#include "ggml-cpu-impl.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-threading.h" @@ -61,7 +62,8 @@ struct ggml_kleidiai_context { ggml_kleidiai_kernels * kernels_q8; int sme_thread_cap; // <= 0 means “SME disabled/unknown”; int thread_hint; // <= 0 means “no hint” -} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 }; + int chunk_multiplier; +} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1, 4 }; static const char* cpu_feature_to_string(cpu_feature f) { if (f == CPU_FEATURE_NONE) { @@ -186,8 +188,9 @@ static void init_kleidiai_context(void) { if (!initialized) { initialized = true; - const char *env_sme = getenv("GGML_KLEIDIAI_SME"); - const char *env_threads = getenv("GGML_TOTAL_THREADS"); + const char *env_sme = getenv("GGML_KLEIDIAI_SME"); + const char *env_threads = getenv("GGML_TOTAL_THREADS"); + const char *env_chunk_mult = getenv("GGML_KLEIDIAI_CHUNK_MULTIPLIER"); const bool cpu_has_sme = ggml_cpu_has_sme(); size_t detected_smcus = 0; @@ -204,6 +207,14 @@ static void init_kleidiai_context(void) { } } + if (env_chunk_mult) { + bool ok = false; + int multiplier = parse_uint_env(env_chunk_mult, "GGML_KLEIDIAI_CHUNK_MULTIPLIER", &ok); + if (ok && multiplier > 0) { + ctx.chunk_multiplier = multiplier; + } + } + // SME policy: // - If CPU doesn't support SME: SME always off. // - Else: @@ -296,6 +307,50 @@ static inline size_t align_up(size_t value, size_t alignment) { return remainder == 0 ? value : value + (alignment - remainder); } +static inline size_t gcd_size(size_t a, size_t b) { + while (b != 0) { + const size_t t = a % b; + a = b; + b = t; + } + return a; +} + +static inline bool lcm_size(size_t a, size_t b, size_t & result) { + if (a == 0 || b == 0) { + result = 0; + return false; + } + const size_t g = gcd_size(a, b); + const size_t q = a / g; + if (q > SIZE_MAX / b) { + return false; + } + result = q * b; + return true; +} + +static inline size_t ceil_div_size(size_t a, size_t b) { + return b == 0 ? 0 : (a + b - 1) / b; +} + +struct kleidiai_block_args { + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; +}; + +static inline kleidiai_block_args kleidiai_get_block_args(ggml_type rhs_type) { + switch (rhs_type) { + case GGML_TYPE_Q4_0: + return { QK4_0, QK4_0, QK4_0 }; + case GGML_TYPE_Q8_0: + return { 0, 0, QK8_0 }; + default: + return { 0, 0, 0 }; + } +} + static inline bool kleidiai_pack_fallback_allowed() { if (ctx.sme_thread_cap <= 0) { return false; @@ -746,8 +801,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t n_step; size_t lhs_packed_size; size_t lhs_offset; - size_t n_offset; - size_t n_cols; + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; + size_t lhs_packed_offset0; int assigned_threads; int thread_begin; int thread_end; @@ -772,6 +829,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { continue; } + const kleidiai_block_args block_args = kleidiai_get_block_args(kernels->rhs_type); + runtime[runtime_count] = { slot, kernels, @@ -784,7 +843,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { kinfo->get_n_step(), 0, 0, - 0, + block_args.lhs_bl, + block_args.rhs_bl, + block_args.pack_bl, 0, 0, 0, @@ -795,45 +856,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { } if (runtime_count == 0) { - ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst); - if (!fallback) { - return false; - } - kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm; - lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info; - rhs_packing_info * rinfo = &fallback->rhs_info; - if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || - !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset || - !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) { - return false; - } - kernel_chain[0] = fallback; - runtime[0] = { - 0, - fallback, - kinfo, - linfo, - kinfo->get_mr(), - kinfo->get_nr(), - kinfo->get_kr(), - kinfo->get_sr(), - kinfo->get_n_step(), - 0, - 0, - 0, - 0, - 0, - 0, - 0, - nullptr - }; - size_t rhs_size_fallback = 0; - const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback); - if (!rhs_base) { - rhs_base = static_cast<const uint8_t *>(src0->data); - } - runtime[0].rhs_base = rhs_base; - runtime_count = 1; + GGML_LOG_WARN("kleidiai: no runtime kernel slot available for supported op %s\n", dst->name); + return false; } const int nth_total = params->nth > 0 ? params->nth : 1; @@ -846,6 +870,13 @@ class tensor_traits : public ggml::cpu::tensor_traits { break; } } + int non_sme_slot = -1; + for (int i = 0; i < runtime_count; ++i) { + if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) != CPU_FEATURE_SME) { + non_sme_slot = i; + break; + } + } const int sme_cap_limit = ctx.sme_thread_cap; const bool use_hybrid = sme_cap_limit > 0 && @@ -864,12 +895,15 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (!hybrid_enabled) { int chosen_slot = 0; if (too_small_for_hybrid && sme_slot != -1) { - chosen_slot = sme_slot; + chosen_slot = nth_total > sme_cap_limit && non_sme_slot != -1 ? non_sme_slot : sme_slot; } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) { chosen_slot = 1; } if (chosen_slot != 0 && chosen_slot < runtime_count) { runtime[0] = runtime[chosen_slot]; + runtime[0].assigned_threads = 0; + runtime[0].thread_begin = 0; + runtime[0].thread_end = 0; } runtime_count = runtime_count > 0 ? 1 : 0; @@ -896,6 +930,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; int fallback_count = 0; + // The current hybrid chain is bounded to SME + one non-SME fallback slot. + GGML_ASSERT(GGML_KLEIDIAI_MAX_KERNEL_SLOTS == 2); for (int i = 0; i < runtime_count; ++i) { if (i == sme_slot) { continue; @@ -952,73 +988,67 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t cursor = 0; for (int i = 0; i < runtime_count; ++i) { - const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type; - const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0; - runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr); + runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, runtime[i].pack_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); runtime[i].lhs_offset = cursor; + runtime[i].lhs_packed_offset0 = runtime[i].lhs_info->get_packed_offset_ex(0, k, runtime[i].lhs_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); cursor += runtime[i].lhs_packed_size; } GGML_ASSERT(cursor <= params->wsize); uint8_t * scratch = static_cast<uint8_t *>(params->wdata); - size_t assigned_cols = 0; - uint64_t weighted_total = 0; - if (runtime_count > 1 && sme_slot != -1) { - for (int i = 0; i < runtime_count; ++i) { - const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; - weighted_total += (uint64_t)runtime[i].assigned_threads * weight; - } - } + size_t common_step = 1; for (int i = 0; i < runtime_count; ++i) { - runtime[i].n_offset = assigned_cols; if (runtime[i].assigned_threads == 0) { - runtime[i].n_cols = 0; continue; } - const size_t remaining_cols = n - assigned_cols; - if (remaining_cols == 0) { - runtime[i].n_cols = 0; - continue; - } - const size_t step = runtime[i].n_step ? runtime[i].n_step : 1; - size_t target = 0; - if (weighted_total > 0) { - const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; - target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total); - } else { - target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total); - } - target = std::min(target, remaining_cols); - size_t aligned = round_down(target, step); - if (aligned == 0 && remaining_cols >= step) { - aligned = step; + size_t next_step = 0; + if (!lcm_size(common_step, runtime[i].n_step ? runtime[i].n_step : 1, next_step)) { + return false; } - runtime[i].n_cols = aligned; - assigned_cols += aligned; + common_step = next_step; } - - if (assigned_cols < n) { - for (int i = runtime_count - 1; i >= 0; --i) { - if (runtime[i].assigned_threads > 0) { - runtime[i].n_cols += n - assigned_cols; - break; - } - } + GGML_ASSERT(common_step > 0); + + const bool disable_chunking = ggml_is_numa(); + const size_t chunk_multiplier = std::max(1, ctx.chunk_multiplier); + const size_t chunk_divisor = (nth_total == 1 || disable_chunking) ? (size_t)nth_total : (size_t)nth_total * chunk_multiplier; + size_t chunk_cols = align_up(std::max<size_t>(1, ceil_div_size(n, chunk_divisor)), common_step); + if (chunk_cols == 0) { + chunk_cols = common_step; } + // If common_step is larger than n, the loop below runs one valid tail chunk + // with cols == n. + const size_t nchunk_size = std::max<size_t>(1, ceil_div_size(n, chunk_cols)); + GGML_ASSERT(nchunk_size <= (size_t)INT_MAX); + const int nchunk = (int)nchunk_size; const size_t dst_stride = dst->nb[1]; + auto run_chunk = [&](runtime_slot & slot, size_t global_start, size_t cols, uint8_t * dst_batch_base) { + const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot.rhs_bl); + const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); + + const uint8_t * lhs_ptr = scratch + slot.lhs_offset + slot.lhs_packed_offset0; + const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; + float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset); + + slot.kernel->run_kernel_ex(m, cols, k, slot.rhs_bl, + lhs_ptr, + rhs_ptr, + dst_ptr, + dst_stride, + sizeof(float), + -FLT_MAX, + FLT_MAX); + }; + for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) { const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2]; uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2]; if (runtime[local_slot].assigned_threads > 0) { runtime_slot & slot = runtime[local_slot]; - const ggml_type slot_rhs_type = slot.kernels->rhs_type; - const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr); int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads; max_threads = std::max<int64_t>(1, max_threads); @@ -1031,8 +1061,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t m_start = (int64_t)local_ith * num_m_per_thread0; const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; - const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); - const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); + const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); + const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0; int64_t remaining = m_count; @@ -1049,7 +1079,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; void * dst_ptr = lhs_packed + dst_off; - slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); + slot.lhs_info->pack_func_ex(take, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); cur += take; remaining -= take; @@ -1057,49 +1087,29 @@ class tensor_traits : public ggml::cpu::tensor_traits { } } + if (ith_total == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth_total); + } + + // Publishes both LHS packing and the initialized dynamic chunk queue. ggml_barrier(params->threadpool); runtime_slot & slot = runtime[local_slot]; - if (slot.n_cols > 0 && slot.assigned_threads > 0) { - int64_t active_threads = slot.assigned_threads; - const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads; - if (max_threads > 0) { - active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads)); + int current_chunk = ith_total; + while (current_chunk < nchunk) { + const size_t global_start = (size_t)current_chunk * chunk_cols; + if (global_start >= n) { + break; } - active_threads = std::max<int64_t>(1, active_threads); - - if (local_ith < active_threads) { - const size_t step = slot.n_step ? slot.n_step : 1; - const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step); - const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0; - const size_t local_start = (size_t)local_ith * chunk0; - const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0; - - if (cols > 0) { - const ggml_type slot_rhs_type = slot.kernels->rhs_type; - const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; - const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; - const size_t global_start = slot.n_offset + local_start; - const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); - const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg); - const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); - - const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset; - const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; - float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset); - - slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg, - lhs_ptr, - rhs_ptr, - dst_ptr, - dst_stride, - sizeof(float), - -FLT_MAX, - FLT_MAX); - } + + const size_t cols = std::min(chunk_cols, n - global_start); + if (cols > 0) { + // KleidiAI GEMM/GEMV kernels accept arbitrary final tail widths; + // only non-tail chunks are guaranteed to be n_step-aligned. + run_chunk(slot, global_start, cols, dst_batch_base); } + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } if (batch_idx != ne12 - 1) { From 5a1feed8ca57b70d002ca0df2abd3db9328a1daa Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Fri, 5 Jun 2026 19:44:40 +0200 Subject: [PATCH 248/289] vulkan: add fwht support for Intel with shmem reduction (llama/23964) * vulkan: add fwht support for Intel with shmem reduction * don't use N as workgroup size * disable subgroup shuffle on MoltenVK AMD * disable fwht shader on Intel Windows due to driver bug --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++++ ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp | 78 +++++++++++++++---- .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 3 files changed, 76 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e7d04634b8a..df410368a79 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5084,6 +5084,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } ++idx; } + } else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) { + // Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147 + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + const uint32_t block_size = std::min(device->subgroup_size, n); + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1); + ++idx; + } } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; @@ -5630,6 +5638,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); +#ifdef __APPLE__ + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_shuffle = false; + } +#endif device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp index 72059d4afc2..a2069964adb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -1,14 +1,16 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#ifndef FWHT_SHMEM #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_shuffle : enable +#endif -layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; - -layout(constant_id = 0) const uint WARP_SIZE = 32; +layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 1) const uint N = 128; +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + layout(push_constant) uniform parameter { uint n_rows; @@ -20,35 +22,72 @@ layout(push_constant) uniform parameter layout(binding = 0, std430) readonly buffer A { float data_a[]; }; layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; -const uint EL_W = N / WARP_SIZE; +const uint EL_W = N / BLOCK_SIZE; + +#ifdef FWHT_SHMEM +shared float shmem[4 * N]; +#endif void main() { - const uint lane = gl_SubgroupInvocationID; - for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; - row < n_rows; - row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { +#ifdef FWHT_SHMEM + const uint tid = gl_LocalInvocationID.x; + const uint shmem_base = gl_LocalInvocationID.y * N; + const uint row_id = gl_LocalInvocationID.y; +#else + const uint tid = gl_SubgroupInvocationID; + const uint row_id = gl_SubgroupID; +#endif + + for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y; + base_row < n_rows; + base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row = base_row + row_id; const uint row_offset = row * N; +#ifndef FWHT_SHMEM + if (row >= n_rows) { + continue; + } +#endif + float reg[EL_W]; [[unroll]] for (uint i = 0; i < EL_W; ++i) { - reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale; + reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid] * scale : 0.0; } +#ifdef FWHT_SHMEM + [[unroll]] + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i]; + } + barrier(); + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)]; + reg[j] = (tid & h) == 0 ? val + other : other - val; + } + barrier(); + } +#else [[unroll]] - for (uint h = 1; h < WARP_SIZE; h <<= 1) { + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { [[unroll]] for (uint j = 0; j < EL_W; ++j) { const float val = reg[j]; const float val2 = subgroupShuffleXor(val, h); - reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + reg[j] = (tid & h) == 0 ? val + val2 : val2 - val; } } +#endif [[unroll]] - for (uint h = WARP_SIZE; h < N; h <<= 1) { - const uint step = h / WARP_SIZE; + for (uint h = BLOCK_SIZE; h < N; h <<= 1) { + const uint step = h / BLOCK_SIZE; [[unroll]] for (uint j = 0; j < EL_W; j += 2 * step) { [[unroll]] @@ -61,9 +100,16 @@ void main() { } } - [[unroll]] - for (uint i = 0; i < EL_W; ++i) { - data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i]; +#ifdef FWHT_SHMEM + if (row < n_rows) { +#endif + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i]; + } +#ifdef FWHT_SHMEM } + barrier(); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index de7dbec2c63..d65cd12b287 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -957,6 +957,7 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("fwht_f32", "fwht.comp", {}); + string_to_spv("fwht_shmem_f32", "fwht.comp", {{"FWHT_SHMEM", "1"}}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); From a87e950a0634481140473324e346badd458ff11f Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Fri, 5 Jun 2026 13:45:25 -0700 Subject: [PATCH 249/289] opencl: improve get_rows, cpy, concat and q6_k flat gemv (llama/24160) * opencl: allow multiple workgroups for large rows * opencl: improve small cpy * opencl: packed concat for small input * opencl: tweak flat q6_K gemv, increase N_DST and remap threads --- ggml/src/ggml-opencl/ggml-opencl.cpp | 71 +++++++++-- ggml/src/ggml-opencl/kernels/concat.cl | 67 +++++++++++ ggml/src/ggml-opencl/kernels/cpy.cl | 59 +++++++++ ggml/src/ggml-opencl/kernels/get_rows.cl | 24 ++-- .../kernels/mul_mv_q6_k_f32_flat.cl | 112 ++++++++---------- 5 files changed, 247 insertions(+), 86 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index c411e4aeaec..2a41215fd13 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -558,7 +558,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; - cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_f32_f32_pack, kernel_cpy_i32_i32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; cl_kernel kernel_mul_mat_f16_f32_1row; @@ -639,7 +639,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; - cl_kernel kernel_concat_f32; + cl_kernel kernel_concat_f32, kernel_concat_f32_pack; cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; @@ -1121,6 +1121,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32_pack = clCreateKernel(prog, "kernel_cpy_f32_f32_pack", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err)); GGML_LOG_CONT("."); } @@ -2615,6 +2616,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_concat_f32_pack = clCreateKernel(prog, "kernel_concat_f32_pack", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -8552,7 +8554,14 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c nth *= 2; } - size_t global_work_size[] = {(size_t)ne10*nth, (size_t)ne11, (size_t)ne12}; + int nchunks = 1; + if (src0->type == GGML_TYPE_F32) { + const int chunk_target = nth * 4; + nchunks = (ne00 + chunk_target - 1) / chunk_target; + nchunks = MAX(1, MIN(nchunks, 64)); + } + + size_t global_work_size[] = {(size_t)ne10*nth*nchunks, (size_t)ne11, (size_t)ne12}; size_t local_work_size[] = {(size_t)nth, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); @@ -11128,7 +11137,9 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con int nth = MIN(64, ne0); - cl_kernel kernel = backend_ctx->kernel_concat_f32; + const bool concat_pack = (dim == 0 && ne0 < 32); + cl_kernel kernel = concat_pack ? backend_ctx->kernel_concat_f32_pack + : backend_ctx->kernel_concat_f32; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -11155,10 +11166,28 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim)); - size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + if (concat_pack) { + // packed kernel needs the dst dims to unflatten its 1-D row index. + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &ne3)); + + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne0, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; + const int nrows = ne1*ne2*ne3; + const int nwg = (nrows + rpw - 1) / rpw; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); + } else { + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { @@ -14536,7 +14565,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (backend_ctx->gpu_family == ADRENO) { nth0 = 64; nth1 = 2; - ndst = 4; + ndst = 16; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } @@ -16633,7 +16662,8 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const kernel = backend_ctx->kernel_cpy_f32_f16; break; case GGML_TYPE_F32: - kernel = backend_ctx->kernel_cpy_f32_f32; + kernel = ne00 < 32 ? backend_ctx->kernel_cpy_f32_f32_pack + : backend_ctx->kernel_cpy_f32_f32; break; default: GGML_ASSERT(false && "not implemented"); @@ -16685,12 +16715,27 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); - const int nth = MIN(64, ne00); + if (kernel == backend_ctx->kernel_cpy_f32_f32_pack) { + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne00, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; // <= base <= maxwg + const int nrows = ne01*ne02*ne03; + const int nwg = (nrows + rpw - 1) / rpw; - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, src1); + } else { + const int nth = MIN(64, ne00); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + } } static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl index 0c1b3d785ca..2fbd7851d3d 100644 --- a/ggml/src/ggml-opencl/kernels/concat.cl +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -49,3 +49,70 @@ kernel void kernel_concat_f32( *y = *x; } } + +kernel void kernel_concat_f32_pack( + global const char * src0, + ulong offset0, + global const char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int dim, + int ne1, + int ne2, + int ne3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int lsz = get_local_size(0); + int tpr = min(ne0, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne1*ne2*ne3; + if (row >= nrows) { + return; + } + + int i1 = row % ne1; + int t = row / ne1; + int i2 = t % ne2; + int i3 = t / ne2; + + int o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + for (int i0 = lane; i0 < ne0; i0 += tpr) { + global const float * x; + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + } + + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl index 820aa538a34..adbd2e766d2 100644 --- a/ggml/src/ggml-opencl/kernels/cpy.cl +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -183,6 +183,65 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_cpy_f32_f32_pack( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int lsz = get_local_size(0); + int tpr = min(ne00, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne01*ne02*ne03; + if (row >= nrows) { + return; + } + + int i01 = row % ne01; + int t = row / ne01; + int i02 = t % ne02; + int i03 = t / ne02; + + // linear index of the first element of this row, unflattened over dst dims + long n = (long)row * ne00; + int i3 = (int)(n / ((long)ne2*ne1*ne0)); + long rm = n - (long)i3*ne2*ne1*ne0; + int i2 = (int)(rm / ((long)ne1*ne0)); + rm -= (long)i2*ne1*ne0; + int i1 = (int)(rm / ne0); + int i0 = (int)(rm - (long)i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = lane; i00 < ne00; i00 += tpr) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + kernel void kernel_cpy_i32_i32( global int * src0, ulong offset0, diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index c2962edc983..9ae4fff09fc 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -82,21 +82,27 @@ kernel void kernel_get_rows_f32( src1 = (global int*)((global char*)src1 + offset1); dst = (global float*)((global char*)dst + offsetd); - int i10 = get_group_id(0); - int i11 = get_group_id(1); - int i12 = get_group_id(2); + int nchunks = get_num_groups(0) / ne10; + int g = get_group_id(0); + int i10 = g / nchunks; + int chunk = g - i10 * nchunks; + int i11 = get_group_id(1); + int i12 = get_group_id(2); int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; int i03 = i12; - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - if (ind >= ne00) { - return; - } - ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; + global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1); + global float * src_row = (global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03); + + int span = (ne00 + nchunks - 1) / nchunks; + int start = chunk * span; + int end = min(start + span, ne00); + + for (int ind = start + get_local_id(0); ind < end; ind += get_local_size(0)) { + dst_row[ind] = src_row[ind]; } } diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl index 86fe09c6dd6..57b90c05ae5 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl @@ -33,13 +33,15 @@ inline float block_q_6_K_dot_y_flat( global uchar * blk_qh, global char * blk_scales, global half * blk_d, - global float * yy, int ib, int ip, int is, - int l0 + int l0, + float4 y0, + float4 y1, + float4 y2, + float4 y3 ) { - int y_offset = 128*ip + l0; int q_offset_l = 64*ip + l0; int q_offset_h = 32*ip + l0; @@ -48,36 +50,28 @@ inline float block_q_6_K_dot_y_flat( global uchar * qh = blk_qh + ib*64 + q_offset_h; global char * sc = blk_scales + ib*16 + is; - global float * y = yy + ib * QK_K + y_offset; - float dall = blk_d[ib]; - float sumf = 0; - float4 sums = {0.f, 0.f, 0.f, 0.f}; - - sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f); - - sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); - - return sumf; + // Vectorized loads: 3 uchar4 weight loads instead of 12 scalar byte reads. + // q_offset_l/h are 4-aligned, so these are aligned vector loads. + uchar4 q1v = vload4(0, q1); + uchar4 q2v = vload4(0, q2); + uchar4 qhv = vload4(0, qh); + + int4 q1i = convert_int4(q1v); + int4 q2i = convert_int4(q2v); + int4 qhi = convert_int4(qhv); + + // Reconstruct the four 6-bit weight groups (low/high nibble of ql OR'd with the + // matching 2-bit plane of qh), same arithmetic as the scalar version, then dot() + // against the cached activation lanes. + float4 w0 = convert_float4((q1i & 0xF) | ((qhi & Q6_K_MASK1) << 4)) - 32.f; + float4 w1 = convert_float4((q2i & 0xF) | ((qhi & Q6_K_MASK2) << 2)) - 32.f; + float4 w2 = convert_float4((q1i >> 4) | ((qhi & Q6_K_MASK3) )) - 32.f; + float4 w3 = convert_float4((q2i >> 4) | ((qhi & Q6_K_MASK4) >> 2)) - 32.f; + + return dall * (dot(y0, w0) * sc[0] + dot(y1, w1) * sc[2] + + dot(y2, w2) * sc[4] + dot(y3, w3) * sc[6]); } #undef N_DST @@ -89,7 +83,7 @@ inline float block_q_6_K_dot_y_flat( #define N_SIMDGROUP 2 #define N_SIMDWIDTH 16 #elif defined (ADRENO_GPU) -#define N_DST 4 +#define N_DST 16 #define N_SIMDGROUP 2 #define N_SIMDWIDTH 64 #endif @@ -146,49 +140,39 @@ kernel void kernel_mul_mv_q6_K_f32_flat( global half * blk_d = (global half *) src0_d + offset_src0_d; global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; - int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 - int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int tid = get_sub_group_local_id()%(N_SIMDWIDTH/BLOCK_STRIDE); // within-super-block part, 0..15 + int ix = get_sub_group_local_id()/(N_SIMDWIDTH/BLOCK_STRIDE); // super-block selector, 0..BLOCK_STRIDE-1 int ip = tid/8; // first or second half of (super) block (0 or 1) int il = tid%8; // each half has 8 parts, one per scale int n = 4; // 4 scales at a time (and 4 sums) int l0 = n*il; // offset into half-block, 0..28 int is = 8*ip + l0/16; // 0, 1, 8, 9 - float4 sumf = 0; + float sumf[N_DST]; + for (int row = 0; row < N_DST; row++) { + sumf[row] = 0.f; + } for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { - if (first_row + 0 < ne01) { - sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0); - } - if (first_row + 1 < ne01) { - sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0); - } - if (first_row + 2 < ne01) { - sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0); - } - if (first_row + 3 < ne01) { - sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0); + global float * y = yy + ib * QK_K + 128*ip + l0; + float4 y0 = vload4(0, y + 0); + float4 y1 = vload4(0, y + 32); + float4 y2 = vload4(0, y + 64); + float4 y3 = vload4(0, y + 96); + + for (int row = 0; row < N_DST; row++) { + if (first_row + row < ne01) { + sumf[row] += block_q_6_K_dot_y_flat( + blk_ql + row*nb*128, blk_qh + row*nb*64, blk_scales + row*nb*16, blk_d + row*nb, + ib, ip, is, l0, y0, y1, y2, y3); + } } } - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), - sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), - sub_group_reduce_add(sumf.s3) - ); - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + for (int row = 0; row < N_DST; row++) { + float tot = sub_group_reduce_add(sumf[row]); + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } } From 1777deff4c014d4bfc895eafe24354deeec80c94 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Sat, 6 Jun 2026 09:11:35 +0200 Subject: [PATCH 250/289] vulkan: check coopmat2 features before reporting support (llama/24186) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index df410368a79..fc9bc8fe376 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6349,6 +6349,15 @@ static void ggml_vk_print_gpu_info(size_t idx) { } #endif +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + } +#endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; if (coopmat2_decode_vector_support) { @@ -6380,6 +6389,19 @@ static void ggml_vk_print_gpu_info(size_t idx) { #endif && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + coopmat2_support = coopmat2_support && + coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads; +#else + coopmat2_support = false; +#endif + coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; #if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) coopmat2_decode_vector_support = false; From 2c139c2e5ee7a0820fdcf5e1a308322027453eac Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen <son@huggingface.co> Date: Mon, 8 Jun 2026 08:03:18 +0200 Subject: [PATCH 251/289] metal : fix im2col 1D case (audio models) (llama/24220) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 5d4b10d34b9..ce847dd8b6f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1738,10 +1738,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + char base[256]; char name[256]; - if (ne00*ne01 <= 1024) { + if (KH*KW <= 1024) { snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); } else { snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); From 4669631d20f32342ae58c346224eeabde251fa07 Mon Sep 17 00:00:00 2001 From: Harkirat Gill <harkirat.gill@amd.com> Date: Mon, 8 Jun 2026 02:33:23 -0400 Subject: [PATCH 252/289] HIP: add gfx1152 and gfx1153 to RDNA3.5 (llama/24129) --- ggml/src/ggml-cuda/vendors/hip.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 5e0e22c7fc2..a6115cd80dc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -219,9 +219,9 @@ #define RDNA3 #endif // defined(__GFX11__) -#if defined(__gfx1150__) || defined(__gfx1151__) +#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) #define RDNA3_5 -#endif // defined(__gfx1150__) || defined(__gfx1151__) +#endif // defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) #if defined(RDNA3) && !defined(RDNA3_5) #define RDNA3_0 From b932ec55298a6ebd2bcd2a0d2e62f9df61e6008d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:52:17 +0300 Subject: [PATCH 253/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 538ef80bc7a..f42565bab0f 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -1e33fed33e87c43aa4c4078e2a9c239d4c1f1bd3 +c95cd071e1bf235ac41ef58c5a5535f73024375c From b31466b4a13d4d55d1bbd9a6055861a8bc3968de Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:51:59 +0300 Subject: [PATCH 254/289] ggml : bump version to 0.14.0 (ggml/1533) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index dc8899b46ef..8f7cb8cdfd2 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 13) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 14) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 4df9a57df23a1eb5a47ce988d606b81e7dc0db27 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:52:27 +0300 Subject: [PATCH 255/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index f42565bab0f..6e1bf3a1f4b 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c95cd071e1bf235ac41ef58c5a5535f73024375c +7142aa6bf9fcaeec0fef8d80fcd90afe4268adf1 From 84bd03a438454a82150853dce83818013c6609d2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:55:06 +0300 Subject: [PATCH 256/289] talk-llama : sync llama.cpp --- examples/talk-llama/CMakeLists.txt | 1 + examples/talk-llama/llama-adapter.cpp | 8 +- examples/talk-llama/llama-arch.cpp | 13 + examples/talk-llama/llama-arch.h | 10 + examples/talk-llama/llama-context.cpp | 224 ++++---- examples/talk-llama/llama-context.h | 17 +- examples/talk-llama/llama-cparams.h | 9 +- examples/talk-llama/llama-ext.h | 14 +- examples/talk-llama/llama-graph.cpp | 322 ++++++++--- examples/talk-llama/llama-graph.h | 109 +++- examples/talk-llama/llama-hparams.cpp | 88 +-- examples/talk-llama/llama-hparams.h | 67 ++- examples/talk-llama/llama-impl.h | 14 + examples/talk-llama/llama-kv-cache-dsa.cpp | 261 +++++++++ examples/talk-llama/llama-kv-cache-dsa.h | 138 +++++ examples/talk-llama/llama-kv-cache-iswa.cpp | 22 +- examples/talk-llama/llama-kv-cache-iswa.h | 4 +- examples/talk-llama/llama-kv-cache.cpp | 222 ++++++-- examples/talk-llama/llama-kv-cache.h | 15 +- examples/talk-llama/llama-kv-cells.h | 2 + .../talk-llama/llama-memory-hybrid-iswa.cpp | 6 +- examples/talk-llama/llama-memory-hybrid.cpp | 7 +- .../talk-llama/llama-memory-recurrent.cpp | 8 +- examples/talk-llama/llama-memory.h | 4 + examples/talk-llama/llama-model-loader.cpp | 15 +- examples/talk-llama/llama-model-saver.cpp | 16 +- examples/talk-llama/llama-model.cpp | 335 ++++++++---- examples/talk-llama/llama-model.h | 14 +- examples/talk-llama/llama-quant.cpp | 4 +- examples/talk-llama/llama-vocab.cpp | 95 +++- examples/talk-llama/llama-vocab.h | 112 ++-- examples/talk-llama/llama.cpp | 9 +- examples/talk-llama/llama.h | 11 +- examples/talk-llama/models/afmoe.cpp | 2 +- examples/talk-llama/models/apertus.cpp | 11 +- examples/talk-llama/models/arcee.cpp | 2 +- examples/talk-llama/models/arctic.cpp | 2 +- examples/talk-llama/models/arwkv7.cpp | 2 +- examples/talk-llama/models/baichuan.cpp | 2 +- examples/talk-llama/models/bailingmoe.cpp | 2 +- examples/talk-llama/models/bailingmoe2.cpp | 21 +- examples/talk-llama/models/bert.cpp | 4 +- examples/talk-llama/models/bitnet.cpp | 2 +- examples/talk-llama/models/bloom.cpp | 2 +- examples/talk-llama/models/chameleon.cpp | 2 +- examples/talk-llama/models/chatglm.cpp | 3 +- examples/talk-llama/models/codeshell.cpp | 3 +- examples/talk-llama/models/cogvlm.cpp | 3 +- examples/talk-llama/models/cohere2.cpp | 4 +- examples/talk-llama/models/command-r.cpp | 3 +- examples/talk-llama/models/dbrx.cpp | 12 +- examples/talk-llama/models/deci.cpp | 3 +- examples/talk-llama/models/deepseek2.cpp | 11 +- examples/talk-llama/models/deepseek2ocr.cpp | 2 +- examples/talk-llama/models/deepseek32.cpp | 499 ++++++++++++++++++ examples/talk-llama/models/dots1.cpp | 3 +- examples/talk-llama/models/dream.cpp | 3 +- examples/talk-llama/models/ernie4-5.cpp | 2 +- examples/talk-llama/models/eurobert.cpp | 2 +- examples/talk-llama/models/exaone-moe.cpp | 22 +- examples/talk-llama/models/exaone.cpp | 2 +- examples/talk-llama/models/exaone4.cpp | 45 +- examples/talk-llama/models/falcon-h1.cpp | 4 +- examples/talk-llama/models/falcon.cpp | 2 +- .../talk-llama/models/gemma-embedding.cpp | 2 +- examples/talk-llama/models/gemma.cpp | 2 +- examples/talk-llama/models/gemma2.cpp | 2 +- examples/talk-llama/models/gemma3.cpp | 2 +- examples/talk-llama/models/gemma3n.cpp | 6 +- .../talk-llama/models/gemma4-assistant.cpp | 200 +++++++ examples/talk-llama/models/gemma4.cpp | 59 ++- examples/talk-llama/models/glm-dsa.cpp | 17 +- examples/talk-llama/models/glm4-moe.cpp | 24 +- examples/talk-llama/models/glm4.cpp | 20 +- examples/talk-llama/models/gpt2.cpp | 3 +- examples/talk-llama/models/gptneox.cpp | 3 +- examples/talk-llama/models/granite-hybrid.cpp | 8 +- examples/talk-llama/models/granite-moe.cpp | 2 +- examples/talk-llama/models/granite.cpp | 39 +- examples/talk-llama/models/grok.cpp | 2 +- examples/talk-llama/models/grovemoe.cpp | 2 +- examples/talk-llama/models/hunyuan-moe.cpp | 2 +- examples/talk-llama/models/internlm2.cpp | 3 +- examples/talk-llama/models/jais.cpp | 2 +- examples/talk-llama/models/jais2.cpp | 2 +- examples/talk-llama/models/jamba.cpp | 6 +- examples/talk-llama/models/jina-bert-v2.cpp | 2 +- examples/talk-llama/models/jina-bert-v3.cpp | 2 +- examples/talk-llama/models/kimi-linear.cpp | 10 +- examples/talk-llama/models/lfm2.cpp | 20 +- examples/talk-llama/models/lfm2moe.cpp | 8 +- examples/talk-llama/models/llada-moe.cpp | 5 +- examples/talk-llama/models/llada.cpp | 4 +- examples/talk-llama/models/llama.cpp | 4 +- examples/talk-llama/models/llama4.cpp | 5 +- examples/talk-llama/models/maincoder.cpp | 3 +- examples/talk-llama/models/mamba.cpp | 2 +- examples/talk-llama/models/mamba2.cpp | 2 +- examples/talk-llama/models/mellum.cpp | 225 ++++++++ examples/talk-llama/models/mimo2.cpp | 23 +- examples/talk-llama/models/minicpm.cpp | 4 +- examples/talk-llama/models/minicpm3.cpp | 2 +- examples/talk-llama/models/minimax-m2.cpp | 2 +- examples/talk-llama/models/mistral3.cpp | 2 +- examples/talk-llama/models/models.h | 42 ++ examples/talk-llama/models/modern-bert.cpp | 13 +- examples/talk-llama/models/mpt.cpp | 2 +- examples/talk-llama/models/nemotron-h.cpp | 10 +- examples/talk-llama/models/nemotron.cpp | 3 +- examples/talk-llama/models/neo-bert.cpp | 2 +- examples/talk-llama/models/nomic-bert-moe.cpp | 2 +- examples/talk-llama/models/nomic-bert.cpp | 2 +- examples/talk-llama/models/olmo.cpp | 2 +- examples/talk-llama/models/olmo2.cpp | 2 +- examples/talk-llama/models/olmoe.cpp | 3 +- examples/talk-llama/models/openai-moe.cpp | 2 +- examples/talk-llama/models/openelm.cpp | 12 +- examples/talk-llama/models/orion.cpp | 2 +- examples/talk-llama/models/pangu-embed.cpp | 3 +- examples/talk-llama/models/phi2.cpp | 2 +- examples/talk-llama/models/phi3.cpp | 2 +- examples/talk-llama/models/phimoe.cpp | 2 +- examples/talk-llama/models/plamo.cpp | 2 +- examples/talk-llama/models/plamo2.cpp | 10 +- examples/talk-llama/models/plamo3.cpp | 2 +- examples/talk-llama/models/plm.cpp | 3 +- examples/talk-llama/models/qwen.cpp | 2 +- examples/talk-llama/models/qwen2.cpp | 3 +- examples/talk-llama/models/qwen2moe.cpp | 3 +- examples/talk-llama/models/qwen3.cpp | 3 +- examples/talk-llama/models/qwen35.cpp | 88 +-- examples/talk-llama/models/qwen35moe.cpp | 87 +-- examples/talk-llama/models/qwen3moe.cpp | 6 +- examples/talk-llama/models/qwen3next.cpp | 12 +- examples/talk-llama/models/qwen3vl.cpp | 3 +- examples/talk-llama/models/qwen3vlmoe.cpp | 3 +- examples/talk-llama/models/refact.cpp | 3 +- examples/talk-llama/models/rnd1.cpp | 5 +- examples/talk-llama/models/rwkv6.cpp | 2 +- examples/talk-llama/models/rwkv6qwen2.cpp | 2 +- examples/talk-llama/models/rwkv7.cpp | 2 +- examples/talk-llama/models/seed-oss.cpp | 3 +- examples/talk-llama/models/smallthinker.cpp | 4 +- examples/talk-llama/models/smollm3.cpp | 2 +- examples/talk-llama/models/stablelm.cpp | 2 +- examples/talk-llama/models/starcoder.cpp | 3 +- examples/talk-llama/models/starcoder2.cpp | 3 +- examples/talk-llama/models/step35.cpp | 314 ++++++++++- examples/talk-llama/models/t5.cpp | 4 +- examples/talk-llama/models/talkie.cpp | 2 +- examples/talk-llama/models/xverse.cpp | 3 +- 151 files changed, 3434 insertions(+), 866 deletions(-) create mode 100644 examples/talk-llama/llama-kv-cache-dsa.cpp create mode 100644 examples/talk-llama/llama-kv-cache-dsa.h create mode 100644 examples/talk-llama/models/deepseek32.cpp create mode 100644 examples/talk-llama/models/gemma4-assistant.cpp create mode 100644 examples/talk-llama/models/mellum.cpp diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 1adeef8f511..13b284ed0e9 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -20,6 +20,7 @@ if (WHISPER_SDL2) llama-io.cpp llama-kv-cache.cpp llama-kv-cache-iswa.cpp + llama-kv-cache-dsa.cpp llama-memory-recurrent.cpp llama-memory-hybrid.cpp llama-memory-hybrid-iswa.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index 4a1aaa955a8..3e0fe66afff 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -41,7 +41,7 @@ bool llama_adapter_cvec::init(const llama_model & model) { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), + /*.mem_size =*/ hparams.n_layer()*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -61,9 +61,9 @@ bool llama_adapter_cvec::init(const llama_model & model) { }; // make tensors - tensors.reserve(hparams.n_layer); + tensors.reserve(hparams.n_layer()); tensors.push_back(nullptr); // there's never a tensor for layer 0 - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { ggml_backend_buffer_type_t buft = model.select_buft(il); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -121,7 +121,7 @@ bool llama_adapter_cvec::apply( layer_start = il_start; layer_end = il_end; - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { assert(tensors[il] != nullptr); const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index e95ba6daac1..6a5d5f8d2ac 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -75,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, + { LLM_ARCH_DEEPSEEK32, "deepseek32" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -134,6 +136,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_TALKIE, "talkie" }, + { LLM_ARCH_MELLUM, "mellum" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -194,6 +197,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, + { LLM_KV_DEEPSTACK_MAPPING, "%s.deepstack_mapping" }, + { LLM_KV_HIDDEN_ACT, "%s.hidden_activation" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -244,6 +249,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, + { LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, @@ -318,12 +324,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, @@ -446,6 +454,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" }, + { LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" }, { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, @@ -758,6 +768,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so // the model loader doesn't fault on the block index. @@ -904,6 +916,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) { case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_GLM_DSA: case LLM_ARCH_BITNET: case LLM_ARCH_T5: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 7c1dcc4d6c2..03b1a265d67 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -61,6 +61,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_ASSISTANT, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -79,6 +80,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_DEEPSEEK2OCR, + LLM_ARCH_DEEPSEEK32, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -138,6 +140,7 @@ enum llm_arch { LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, LLM_ARCH_TALKIE, + LLM_ARCH_MELLUM, LLM_ARCH_UNKNOWN, }; @@ -198,6 +201,8 @@ enum llm_kv { LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, + LLM_KV_DEEPSTACK_MAPPING, + LLM_KV_HIDDEN_ACT, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -248,6 +253,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ATTENTION_SHARED_KV_LAYERS, + LLM_KV_ATTENTION_RECURRENT_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_SWA, @@ -307,12 +313,14 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, LLM_KV_TOKENIZER_FIM_PAD_ID, LLM_KV_TOKENIZER_FIM_REP_ID, LLM_KV_TOKENIZER_FIM_SEP_ID, + LLM_KV_TOKENIZER_SUPPRESS_TOKENS, LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, @@ -550,6 +558,8 @@ enum llm_tensor { LLM_TENSOR_INDEXER_PROJ, LLM_TENSOR_INDEXER_ATTN_K, LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_PROJ_PRE, + LLM_TENSOR_NEXTN_PROJ_POST, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index ad36c06667d..9a40c4366af 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -58,19 +58,21 @@ llama_context::llama_context( cparams.n_rs_seq = 0; } - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch; - cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; - cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; - cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; - cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; - cparams.embeddings_pre_norm = false; - cparams.embeddings_pre_norm_masked = false; - cparams.offload_kqv = params.offload_kqv; - cparams.no_perf = params.no_perf; - cparams.pooling_type = params.pooling_type; - cparams.warmup = false; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; + cparams.embeddings = params.embeddings; + cparams.embeddings_nextn = false; + cparams.embeddings_nextn_masked = false; + cparams.offload_kqv = params.offload_kqv; + cparams.no_perf = params.no_perf; + cparams.warmup = false; + + cparams.ctx_type = params.ctx_type; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -83,7 +85,17 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; - cparams.ctx_type = params.ctx_type; + cparams.ctx_other = nullptr; + + // TODO: more generic + if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { + if (params.ctx_other == nullptr) { + // TODO: change from runtime_error to llama_exception to avoid printing error message + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)"); + } + + cparams.ctx_other = params.ctx_other; + } // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later @@ -182,6 +194,8 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max; + cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; @@ -227,6 +241,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); + LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -296,10 +311,11 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, - /*.swa_full =*/ params.swa_full, - /*.ctx_type= */ cparams.ctx_type, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_other =*/ llama_get_memory(cparams.ctx_other), }; memory.reset(model.create_memory(params_mem, cparams)); @@ -337,7 +353,7 @@ llama_context::llama_context( // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = model.n_devices() > 1 && - model.n_gpu_layers() > model.hparams.n_layer && + model.n_gpu_layers() > model.hparams.n_layer_all && model.split_mode() == LLAMA_SPLIT_MODE_LAYER && cparams.offload_kqv && !model.has_tensor_overrides(); @@ -531,7 +547,7 @@ void llama_context::sched_reserve() { // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, - // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553). + // the ggml_mul_mat assertion fails. const uint32_t n_tokens_ch = 16*n_seqs; auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); if (!gf) { @@ -577,16 +593,18 @@ void llama_context::sched_reserve() { int n_splits_tg = -1; int n_nodes_tg = -1; + const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max); + // reserve pp (prompt processing) graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); if (!gf) { if (cparams.pipeline_parallel) { LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); cparams.pipeline_parallel = false; sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get()); } if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -614,7 +632,7 @@ void llama_context::sched_reserve() { // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -774,7 +792,9 @@ bool llama_context::memory_update(bool optimize) { const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + const uint32_t n_outputs_max = std::min(n_tokens, cparams.n_outputs_max); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_max, mctx.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } @@ -882,34 +902,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } -float * llama_context::get_embeddings_pre_norm() { +float * llama_context::get_embeddings_nextn() { output_reorder(); - return embd_pre_norm.data; + return embd_nextn.data; } -float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { +float * llama_context::get_embeddings_nextn_ith(int32_t i) { output_reorder(); try { - if (embd_pre_norm.data == nullptr) { - throw std::runtime_error("no pre-norm embeddings"); + if (embd_nextn.data == nullptr) { + throw std::runtime_error("no nextn embeddings"); } - const uint32_t n_embd = model.hparams.n_embd; + const uint32_t n_embd = model.hparams.n_embd_out(); - if (!cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm rows are stored densely, indexed by raw token position. - if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { - throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + if (!cparams.embeddings_nextn_masked) { + // unmasked: nextn rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd)); } - return embd_pre_norm.data + (size_t) i * n_embd; + return embd_nextn.data + (size_t) i * n_embd; } const int64_t j = output_resolve_row(i); - return embd_pre_norm.data + j*n_embd; + return embd_nextn.data + j*n_embd; } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); + LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG GGML_ABORT("fatal error"); #else @@ -1098,11 +1118,11 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } -void llama_context::set_embeddings_pre_norm(bool value, bool masked) { +void llama_context::set_embeddings_nextn(bool value, bool masked) { LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); - cparams.embeddings_pre_norm = value; - cparams.embeddings_pre_norm_masked = masked; + cparams.embeddings_nextn = value; + cparams.embeddings_nextn_masked = masked; } void llama_context::set_causal_attn(bool value) { @@ -1319,7 +1339,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), // so accept either present rather than requiring exactly one. GGML_ASSERT(batch_inp.token || batch_inp.embd); @@ -1392,9 +1412,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); - auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1460,14 +1480,14 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - // extract pre-norm embeddings (hidden state before the final output norm) - if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + // extract nextn embeddings (hidden state before the final output norm) + if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + const uint32_t n_embd = hparams.n_embd_out(); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float)); } // TODO: hacky solution @@ -1622,7 +1642,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s } int llama_context::decode(const llama_batch & batch_inp) { - // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), // so accept either present rather than requiring exactly one. GGML_ASSERT(batch_inp.token || batch_inp.embd); @@ -1822,9 +1842,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; - auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1905,22 +1925,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // extract pre-norm embeddings (hidden state before the final output norm) + // extract nextn embeddings before // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. { - const bool masked = cparams.embeddings_pre_norm_masked; + const bool masked = cparams.embeddings_nextn_masked; const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; - if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + const uint32_t n_embd = hparams.n_embd_out(); + float * embd_nextn_out = embd_nextn.data + offset*n_embd; - GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float)); } } @@ -2009,12 +2029,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; - bool has_embd_pre_norm = cparams.embeddings_pre_norm; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_nextn = cparams.embeddings_nextn; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -2026,14 +2045,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; - embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0; - if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm row exists for every token in the batch, not just + if (has_embd_nextn && !cparams.embeddings_nextn_masked) { + // unmasked: nextn row exists for every token in the batch, not just // those flagged via batch.logits[i] -> size by token count instead. - embd_pre_norm.size = (size_t) n_embd * n_batch; + embd_nextn.size = (size_t) n_embd_out * n_batch; } // Allocate backend sampling output buffers if there are backend samplers configured. @@ -2050,7 +2069,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + (logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required @@ -2067,7 +2086,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; - embd_pre_norm.data = nullptr; + embd_nextn.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -2096,8 +2115,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0}; offset += embd.size * sizeof(float); - embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0}; - offset += embd_pre_norm.size * sizeof(float); + embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0}; + offset += embd_nextn.size * sizeof(float); if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; @@ -2140,6 +2159,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { this->n_outputs = 0; + GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max); + return n_outputs_max; } @@ -2163,9 +2184,9 @@ void llama_context::output_reorder() { } } - if (embd_pre_norm.size > 0) { + if (embd_nextn.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]); } } @@ -2226,8 +2247,6 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::max(n_outputs, n_tokens); - LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } @@ -2343,7 +2362,7 @@ llm_graph_cb llama_context::graph_get_cb() const { // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched - const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; + const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer_all; if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { const auto & dev_layer = model.dev_layer(il); @@ -3337,6 +3356,7 @@ llama_context_params llama_context_default_params() { /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, /*.n_rs_seq =*/ 0, + /*.n_outputs_max =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, @@ -3366,6 +3386,7 @@ llama_context_params llama_context_default_params() { /*.kv_unified =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.ctx_other =*/ nullptr, }; return result; @@ -3403,15 +3424,11 @@ llama_context * llama_init_from_model( LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); return nullptr; } - if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { - LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); - return nullptr; - } } if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); - for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { if (model->hparams.n_embd_head_k(il) % blck_size != 0) { LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); @@ -3422,7 +3439,7 @@ llama_context * llama_init_from_model( if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); - for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { if (model->hparams.n_embd_head_v(il) % blck_size != 0) { LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); @@ -3444,12 +3461,11 @@ llama_context * llama_init_from_model( } if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && - model->hparams.nextn_predict_layers == 0) { + model->hparams.n_layer_nextn == 0) { LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); return nullptr; } - try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3584,20 +3600,28 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { - ctx->set_embeddings_pre_norm(value, masked); +void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_nextn(value, masked); } -float * llama_get_embeddings_pre_norm(llama_context * ctx) { +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + if (!ctx) { + return nullptr; + } + + return ctx->get_memory(); +} + +float * llama_get_embeddings_nextn(llama_context * ctx) { ctx->synchronize(); - return ctx->get_embeddings_pre_norm(); + return ctx->get_embeddings_nextn(); } -float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { +float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); - return ctx->get_embeddings_pre_norm_ith(i); + return ctx->get_embeddings_nextn_ith(i); } bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { @@ -3651,7 +3675,7 @@ struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { - auto * memory = ctx->get_memory(); + auto memory = ctx->get_memory(); llama_memory_context_ptr mctx; if (memory) { mctx = memory->init_full(); @@ -3691,10 +3715,6 @@ int32_t llama_set_adapter_cvec( // memory // -llama_memory_t llama_get_memory(const struct llama_context * ctx) { - return ctx->get_memory(); -} - void llama_memory_clear(llama_memory_t mem, bool data) { if (!mem) { return; @@ -4005,3 +4025,7 @@ void llama_opt_epoch( llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { return ctx->memory_breakdown(); } + +llama_context * llama_get_ctx_other(struct llama_context * ctx) { + return ctx->get_cparams().ctx_other; +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index d03f681d4a1..6f8f59a22a3 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -6,6 +6,7 @@ #include "llama-graph.h" #include "llama-adapter.h" #include "llama-impl.h" +#include "llama-memory.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -84,8 +85,8 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); - float * get_embeddings_pre_norm(); - float * get_embeddings_pre_norm_ith(int32_t i); + float * get_embeddings_nextn(); + float * get_embeddings_nextn_ith(int32_t i); llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -110,7 +111,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); - void set_embeddings_pre_norm(bool value, bool masked); + void set_embeddings_nextn(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); @@ -273,7 +274,7 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr<llama_memory_i> memory; + llama_memory_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) buffer_view<float> logits = {nullptr, 0}; @@ -282,10 +283,10 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view<float> embd = {nullptr, 0}; - // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) - // populated only when cparams.embeddings_pre_norm is enabled and the model graph - // sets llm_graph_result::t_h_pre_norm - buffer_view<float> embd_pre_norm = {nullptr, 0}; + // hidden state required by the nextn layers (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_nextn is enabled and the model graph + // sets llm_graph_result::t_h_nextn + buffer_view<float> embd_nextn = {nullptr, 0}; struct sampling_info { // !samplers.empty() to check if any samplers are active diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 20ec59fe335..8a35d389ef4 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -13,6 +13,7 @@ struct llama_cparams { uint32_t n_ubatch; uint32_t n_seq_max; uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback + uint32_t n_outputs_max; // max outputs supported by the context int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -28,8 +29,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; - bool embeddings_pre_norm; // also extract the hidden state before the final output norm - bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 + bool embeddings_nextn; // also extract the hidden state before the final output norm + bool embeddings_nextn_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -38,7 +39,7 @@ struct llama_cparams { bool fused_gdn_ch; // use fused gated delta net (chunked) bool auto_fgdn; bool no_perf; - bool warmup; + bool warmup; // TODO: remove [TAG_LLAMA_GRAPH_NO_WARMUP] bool op_offload; bool kv_unified; bool pipeline_parallel; @@ -48,4 +49,6 @@ struct llama_cparams { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + llama_context * ctx_other; }; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index edfa71c207c..bd74544129b 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -89,18 +89,16 @@ LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * m LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); -// -// pre-norm embeddings (hidden state before the final output norm) -// - -// Set whether the context outputs pre-norm embeddings or not +// Set whether the context outputs nextn embeddings or not // If masked == true, output the embeddings only for the tokens with batch.logits != 0 // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits -LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); +LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); -LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); -LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); + +LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fc027de8b39..da7a9295561 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -29,7 +30,10 @@ static ggml_tensor * build_attn_inp_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + // flash attention requires an f16 mask + const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(res); ggml_set_name(res, "attn_inp_kq_mask"); @@ -102,6 +106,39 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) { + const int64_t n_tokens = ubatch->n_tokens; + + if (ubatch->token) { + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + } else { + // note: mtmd embedding input goes through here + GGML_ASSERT(ubatch->embd); + GGML_ASSERT(n_embd == embd->ne[0]); + + ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } + + // TODO: extend llama_ubatch to differentiate between token embeddings and hidden states + // for now, we assume that the hidden state is always provided as an embedding + // ref: https://github.com/ggml-org/llama.cpp/pull/23643 + if (ubatch->embd) { + GGML_ASSERT(n_embd == h->ne[0]); + + ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } +} + +bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (h && h->ne[1] == params.ubatch.n_tokens); + + return res; +} + void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && pos) { const int64_t n_tokens = ubatch->n_tokens; @@ -348,7 +385,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +template <typename T> +static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); const char * swa_type_str = "unknown"; @@ -359,7 +397,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; }; - LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -372,7 +410,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { LLAMA_LOG_DEBUG(" %2d ", i); for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { - float val = data[i * n_kv + j]; + float val = llama_cast<float>(data[i * n_kv + j]); if (val == -INFINITY) { LLAMA_LOG_DEBUG(" ∞"); } else { @@ -387,7 +425,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { + const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) { + using T = std::remove_reference_t<decltype(*data)>; + std::fill(data, data + ne, llama_cast<T>(-INFINITY)); + for (int i1 = 0; i1 < n_tokens; ++i1) { const llama_seq_id s1 = ubatch->seq_id[i1][0]; const llama_pos p1 = ubatch->pos[i1]; @@ -413,38 +454,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { continue; } - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f); } } - }; - - { - GGML_ASSERT(self_kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - - float * data = (float *) self_kq_mask->data; - - std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); - - fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); if (debug) { - print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); + print_mask(data, n_tokens, n_kv, n_swa, swa_type); } + }; + + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + if (self_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + } else { + fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(self_kq_mask_swa); GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - - float * data = (float *) self_kq_mask_swa->data; - - std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); - - fill_mask(data, hparams.n_swa, hparams.swa_type); - - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + if (self_kq_mask_swa->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + } else { + fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); } } } @@ -499,23 +532,51 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) { + mctx->get_mla()->set_input_k_idxs(self_k_idxs_mla, ubatch); + + mctx->get_mla()->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_idxs(self_k_idxs_lid, ubatch); + + mctx->get_lid()->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_rot(self_k_rot_lid); +} + +bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast<const llama_kv_cache_dsa_context *>(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs_mla->ne[0] == params.ubatch.n_tokens; + res &= self_k_idxs_lid->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask_mla, mctx->get_mla(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_lid, mctx->get_lid(), params.ubatch, params.cparams); + + return res; +} + void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { // base tensors may not be allocated if there are no non-SWA attention layers if (self_k_idxs && self_k_idxs->buffer) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); } @@ -544,18 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { if (self_k_idxs && self_k_idxs->buffer) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + return res; } @@ -568,23 +629,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing - float * data = (float *) cross_kq_mask->data; - - for (int i = 0; i < n_tokens; ++i) { - GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + const auto fill_mask = [&](auto * data) { + using T = std::remove_reference_t<decltype(*data)>; + for (int i = 0; i < n_tokens; ++i) { + GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; + } } - } - data[i*n_enc + j] = f; + data[i*n_enc + j] = llama_cast<T>(f); + } } + }; + + if (cross_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) cross_kq_mask->data); + } else { + fill_mask((float *) cross_kq_mask->data); } } @@ -688,7 +756,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + } + if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) { attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); } @@ -696,7 +766,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + } + if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } @@ -742,18 +814,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; @@ -861,8 +933,8 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } - if (t_h_pre_norm != nullptr) { - ggml_set_output(t_h_pre_norm); + if (t_h_nextn != nullptr) { + ggml_set_output(t_h_nextn); } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { @@ -937,7 +1009,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : cparams (params.cparams), ubatch (params.ubatch), n_embd (hparams.n_embd), - n_layer (hparams.n_layer), + n_layer (hparams.n_layer()), + n_layer_nextn (hparams.n_layer_nextn), n_rot (hparams.n_rot()), n_ctx (cparams.n_ctx), n_head (hparams.n_head()), @@ -1791,7 +1864,12 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { res->t_inp_embd = cur; // For Granite architecture - if (hparams.f_embedding_scale != 0.0f) { + // NOTE: Only apply scale to token inputs. Raw embeddings are assumed to be + // multimodal inputs that should not be scaled. + if (ubatch.token && hparams.f_embedding_scale != 0.0f) { + if (!ggml_is_contiguous(cur)) { + cur = ggml_cont(ctx0, cur); + } cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } @@ -2088,17 +2166,20 @@ ggml_tensor * llm_graph_context::build_attn_mha( llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const { auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } else { inp->self_kq_mask_swa = nullptr; inp->self_kq_mask_swa_cnv = nullptr; @@ -2175,7 +2256,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl( inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); @@ -2282,7 +2363,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } return inp; @@ -2354,6 +2435,82 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + ggml_tensor * top_k, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx->get_mla(); + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs_mla(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask_mla(); + + // prepare new kq mask - starts filled with -INFINITY + ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); + + // reshape KQ mask into tensor with rows of size 1: + // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream] + kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0); + + // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1] + ggml_tensor * top_k_3d = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0); + + // prepare zero-filled tensor with rows of size 1: [1, n_top_k, n_batch, n_stream] + // this will be our source of zero values for unmasking top k mask elements + ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k_3d->ne[0], top_k_3d->ne[1], top_k_3d->ne[2]); + zeros = ggml_fill(ctx0, zeros, 0.0f); + + // modify KQ mask by unmasking elements that are in top_k indices + // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1]) + ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k_3d); + + // reshape to restore the original shape of KQ mask: + // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream] + kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0); + + // combine with the original kq mask + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur, wo_s); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -2446,10 +2603,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1); ggml_set_input(inp->cross_kq_mask); - inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + inp->cross_kq_mask_cnv = inp->cross_kq_mask; return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); } @@ -2497,6 +2657,34 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const { + const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx); + + auto inp = std::make_unique<llm_graph_input_attn_k_dsa>(hparams, cparams, mctx_cur); + + { + inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams); + inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla; + } + + { + inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch); + + // ensure F32 mask + auto cparams_copy = cparams; + cparams_copy.flash_attn = false; + + inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy); + inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid; + + inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0); + } + + return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp)); +} + // TODO: maybe separate the inner implementation into a separate function // like with the non-sliding window equivalent // once sliding-window hybrid caches are a thing. @@ -2510,7 +2698,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } { @@ -2520,7 +2708,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); @@ -2689,7 +2877,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask; } { @@ -2697,7 +2885,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa; } auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index bf6778237e6..6793846e3ea 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -22,6 +22,7 @@ struct llama_layer; struct llama_memory_context_i; class llama_kv_cache_context; +class llama_kv_cache_dsa_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -35,7 +36,8 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DECODER_MTP, }; -enum llm_ffn_op_type { +enum llm_ffn_op_type : int { + LLM_FFN_NONE = 0, // sentinel: unset; archs must assign before use LLM_FFN_SILU, LLM_FFN_GELU, LLM_FFN_RELU, @@ -121,6 +123,23 @@ class llm_graph_input_embd : public llm_graph_input_i { const int64_t n_embd = 0; }; +// similar to llm_graph_input_embd but with an additional hidden state input +class llm_graph_input_embd_h : public llm_graph_input_i { +public: + llm_graph_input_embd_h(int64_t n_embd) : n_embd(n_embd) {} + virtual ~llm_graph_input_embd_h() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * tokens = nullptr; // I32 [n_batch] + ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + ggml_tensor * h = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; +}; + class llm_graph_input_pos : public llm_graph_input_i { public: llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {} @@ -274,10 +293,10 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } // n_tokens == n_batch - ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -307,8 +326,8 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] // note: assumes v_rot^2 == I ggml_tensor * self_k_rot = nullptr; @@ -347,8 +366,8 @@ class llm_graph_input_attn_k : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -356,6 +375,44 @@ class llm_graph_input_attn_k : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +class llm_graph_input_attn_k_dsa : public llm_graph_input_i { +public: + llm_graph_input_attn_k_dsa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_dsa_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k_dsa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs_mla() const { return self_k_idxs_mla; } + ggml_tensor * get_k_idxs_lid() const { return self_k_idxs_lid; } + + ggml_tensor * get_kq_mask_mla() const { return self_kq_mask_mla_cnv; } + ggml_tensor * get_kq_mask_lid() const { return self_kq_mask_lid; } + + ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch] + ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask_mla = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * self_k_rot_lid = nullptr; + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_dsa_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -385,10 +442,10 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_k_rot = nullptr; ggml_tensor * self_v_rot = nullptr; @@ -411,8 +468,8 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; @@ -646,7 +703,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } - ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } + ggml_tensor * get_h_nextn() const { return t_h_nextn; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -675,7 +732,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; - ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm + ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map<llama_seq_id, ggml_tensor*> t_sampled_logits; std::map<llama_seq_id, ggml_tensor*> t_candidates; @@ -727,6 +784,7 @@ struct llm_graph_context { const int64_t n_embd; const int64_t n_layer; + const int64_t n_layer_nextn; const int64_t n_rot; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) const int64_t n_head; @@ -956,6 +1014,23 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_k_dsa * build_attn_inp_k_dsa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + ggml_tensor * top_k, // [n_indexer_top_k, n_tokens] + float kq_scale, + int il) const; + llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; // note: if k_cur or v_cur are not provided, they will not be stored in the memory diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 2239309c8fb..2bf57687382 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -7,19 +7,39 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern != 0); } } else { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); } } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_swa_impl[il] = false; + } +} + +void llama_hparams::set_recr_pattern(uint32_t n_pattern, bool dense_first) { + if (dense_first) { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0); + } + } else { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } + } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_recr_impl[il] = false; + } } bool llama_hparams::is_swa_any() const { - for (uint32_t il = 0; il < n_layer; ++il) { - if (swa_layers[il]) { + for (uint32_t il = 0; il < n_layer_all; ++il) { + if (is_swa_impl[il]) { return true; } } @@ -28,7 +48,7 @@ bool llama_hparams::is_swa_any() const { } uint32_t llama_hparams::n_head(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_arr[il]; } @@ -36,7 +56,7 @@ uint32_t llama_hparams::n_head(uint32_t il) const { } uint32_t llama_hparams::n_head_kv(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_kv_arr[il]; } @@ -44,7 +64,7 @@ uint32_t llama_hparams::n_head_kv(uint32_t il) const { } uint32_t llama_hparams::n_ff(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_ff_arr[il]; } @@ -63,7 +83,7 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const { } uint32_t llama_hparams::n_rot(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_rot_swa : n_rot_full; } @@ -71,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const { } uint32_t llama_hparams::n_embd_inp() const { + if (n_embd_inp_impl > 0) { + return n_embd_inp_impl; + } + uint32_t n_embd_inp = n_embd; if (n_deepstack_layers > 0) { @@ -85,7 +109,7 @@ uint32_t llama_hparams::n_embd_out() const { } uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full; } @@ -93,7 +117,7 @@ uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { } uint32_t llama_hparams::n_embd_head_v(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full; } @@ -114,7 +138,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { bool llama_hparams::is_n_embd_k_gqa_variable() const { const uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_k_gqa(il)) { return true; } @@ -125,7 +149,7 @@ bool llama_hparams::is_n_embd_k_gqa_variable() const { bool llama_hparams::is_n_embd_v_gqa_variable() const { const uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_v_gqa(il)) { return true; } @@ -136,7 +160,7 @@ bool llama_hparams::is_n_embd_v_gqa_variable() const { uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_k_gqa(il)); } @@ -145,7 +169,7 @@ uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t llama_hparams::n_embd_v_gqa_max() const { uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_v_gqa(il)); } @@ -193,12 +217,12 @@ uint32_t llama_hparams::n_embd_s() const { return ssm_d_state * ssm_d_inner; } -bool llama_hparams::is_recurrent(uint32_t il) const { - if (il < n_layer) { - return recurrent_layer_arr[il]; +bool llama_hparams::is_recr(uint32_t il) const { + if (il < n_layer_all) { + return is_recr_impl[il]; } - GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); } uint32_t llama_hparams::n_pos_per_embd() const { @@ -206,11 +230,11 @@ uint32_t llama_hparams::n_pos_per_embd() const { } bool llama_hparams::is_swa(uint32_t il) const { - if (il < n_layer) { - return swa_layers[il]; + if (il < n_layer_all) { + return is_swa_impl[il]; } - GGML_ABORT("fatal error"); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); } bool llama_hparams::is_mla() const { @@ -229,12 +253,6 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { - if (kv_only_nextn) { - // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; - // the leading trunk blocks are not executed in this graph. - return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); - } - if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; @@ -247,16 +265,8 @@ bool llama_hparams::has_kv(uint32_t il) const { return true; } -uint32_t llama_hparams::n_layer_kv() const { - uint32_t res = 0; - - for (uint32_t il = 0; il < n_layer; ++il) { - if (has_kv(il)) { - res++; - } - } - - return res; +uint32_t llama_hparams::n_layer() const { + return n_layer_all - n_layer_nextn; } bool llama_hparams::use_mrope() const { diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index e2d051edc6c..032944cb481 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -23,6 +23,9 @@ enum llama_swa_type { LLAMA_SWA_TYPE_SYMMETRIC = 3, }; +// forward declaration; full definition in llama-graph.h +enum llm_ffn_op_type : int; + struct llama_hparams_posnet { uint32_t n_embd; uint32_t n_layer; @@ -34,6 +37,9 @@ struct llama_hparams_convnext { }; struct llama_hparams { + // note: use the `_impl` suffix to avoid name conflict between members and getters + // for example: n_embd_out() vs n_embd_out_impl + bool vocab_only; bool no_alloc; bool rope_finetuned; @@ -42,12 +48,15 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + uint32_t n_layer_all; + uint32_t n_layer_nextn = 0; uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // TODO: this needs to be reworked + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + // different head size for full_attention and SWA layers uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head @@ -90,9 +99,6 @@ struct llama_hparams { uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; uint32_t moe_latent_size = 0; - uint32_t nextn_predict_layers = 0; - - bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) float f_norm_eps; float f_norm_rms_eps; @@ -134,11 +140,15 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == 1, then layer il is SWA - // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA) + + // if is_swa_impl[il] == 1, then layer il is SWA + // if is_swa_impl[il] == 0, then layer il is dense (i.e. non-SWA) // by default, all layers are dense // note: using uint32_t type for compatibility reason - std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers; + std::array<uint32_t, LLAMA_MAX_LAYERS> is_swa_impl; + + // for hybrid state space models + std::array<uint32_t, LLAMA_MAX_LAYERS> is_recr_impl; // for State Space Models uint32_t ssm_d_conv = 0; @@ -150,9 +160,6 @@ struct llama_hparams { // for Kimi Linear KDA uint32_t n_embd_head_kda = 0; - // for hybrid state space models - std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr; - bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -178,6 +185,9 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // input embedding dimension (0 = use n_embd) + uint32_t n_embd_inp_impl = 0; + // output embedding dimension (0 = use n_embd) uint32_t n_embd_out_impl = 0; @@ -212,8 +222,19 @@ struct llama_hparams { uint32_t indexer_top_k = 0; // qwen3vl deepstack + // When parsed from GGUF, this implies the first N layers consume the first + // N deepstack embeddings. Use deepstack_mapping_arr if you need a more + // complex mapping. If using deepstack_mapping_arr, also make sure to set + // n_deepstack_layers to the number of unique deepstack layers so that + // n_embd_imp is accurate (see granite.cpp). + // TODO: can be expressed via the `new n_embd_inp_impl` and remove this param uint32_t n_deepstack_layers = 0; + // deepstack layer array (Granite4 Vision) + // -1 => no deepstack + // >=0 => input embedding index for deepstack injection + std::array<int32_t, LLAMA_MAX_LAYERS> deepstack_mapping_arr; + // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; @@ -227,6 +248,14 @@ struct llama_hparams { enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // Resolved FFN gated activation flavor for archs that read + // `<arch>.hidden_activation` from the GGUF (e.g. ModernBert derivatives). + // Defaults to LLM_FFN_NONE (sentinel = 0); the mapping from the GGUF + // string to a real op is done at hparam-load time via + // llm_ffn_op_type_from_string() in llama-model.cpp, mirroring how + // rope_scaling_type_train is handled. + enum llm_ffn_op_type llm_ffn_op; + // Step35: optional per-layer clamps for (Swi)GLU std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_exp; // clamping for expert FFN std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_shexp; // shared expert @@ -255,6 +284,13 @@ struct llama_hparams { // return true if one of the layers is SWA bool is_swa_any() const; + bool is_swa(uint32_t il) const; + + void set_recr_pattern(uint32_t n_pattern, bool dense_first = false); + + // whether or not the given layer is recurrent (for hybrid models) + bool is_recr(uint32_t il) const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; @@ -296,13 +332,8 @@ struct llama_hparams { // dimension of the recurrent state embeddings uint32_t n_embd_s() const; - // whether or not the given layer is recurrent (for hybrid models) - bool is_recurrent(uint32_t il) const; - uint32_t n_pos_per_embd() const; - bool is_swa(uint32_t il) const; - // note: currently only support if either all or none of the layers are MLA bool is_mla() const; @@ -311,8 +342,8 @@ struct llama_hparams { bool has_kv(uint32_t il) const; - // number of layers for which has_kv() returns true - uint32_t n_layer_kv() const; + // number of effective layers (excludes nextn layers) + uint32_t n_layer() const; // note that this function uses different SWA parameters from those in the hparams // note: inlined on purpose for performance reasons diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index e4f35c8e53d..7923c3f7ed5 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -3,6 +3,7 @@ #include "ggml.h" // for ggml_log_level #include <string> +#include <type_traits> #include <vector> #ifdef __GNUC__ @@ -40,6 +41,19 @@ struct no_init { no_init() = default; }; +template <typename dst_t, typename src_t> +static inline dst_t llama_cast(src_t v) { + if constexpr (std::is_same_v<src_t, dst_t>) { + return v; + } else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) { + return ggml_fp16_to_fp32(v); + } else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) { + return ggml_fp32_to_fp16(v); + } else { + static_assert(std::is_same_v<dst_t, void>, "unsupported type combination"); + } +} + struct time_meas { time_meas(int64_t & t_acc, bool disable = false); ~time_meas(); diff --git a/examples/talk-llama/llama-kv-cache-dsa.cpp b/examples/talk-llama/llama-kv-cache-dsa.cpp new file mode 100644 index 00000000000..916ab653756 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.cpp @@ -0,0 +1,261 @@ +#include "llama-kv-cache-dsa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include <algorithm> +#include <cassert> + +// +// llama_kv_cache_dsa +// + +llama_kv_cache_dsa::llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : + hparams_lid(model.hparams), n_stream(unified ? 1 : n_seq_max) { + + LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size); + + kv_mla = std::make_unique<llama_kv_cache>( + model, model.hparams, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); + + // we use llama_kv_cache for caching indexer keys + // by hand-tweaking some hparams we fool it to create + // indexer key cache tensors with correct dimensions + // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823 + + // DSA lightning indexer uses MQA with single key head + std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); + hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; + hparams_lid.rope_type = LLAMA_ROPE_TYPE_NEOX; + + LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size); + + kv_lid = std::make_unique<llama_kv_cache>( + model, hparams_lid, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); +} + +void llama_kv_cache_dsa::clear(bool data) { + kv_mla->clear(data); + kv_lid->clear(data); +} + +bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_mla->seq_rm(seq_id, p0, p1); + res = res & kv_lid->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) { + kv_mla->seq_keep(seq_id); + kv_lid->seq_keep(seq_id); +} + +void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_mla->seq_add(seq_id, p0, p1, shift); + kv_lid->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_mla->seq_div(seq_id, p0, p1, d); + kv_lid->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const { + return kv_mla->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const { + return kv_mla->seq_pos_max(seq_id); +} + +std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const { + std::map<ggml_backend_buffer_type_t, size_t> mb = kv_mla->memory_breakdown(); + for (const auto & buft_size : kv_lid->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + do { + balloc.split_reset(); + + std::vector<llama_ubatch> ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + auto sinfos_mla = kv_mla->prepare(ubatches); + if (sinfos_mla.empty()) { + break; + } + + auto sinfos_lid = kv_lid->prepare(ubatches); + if (sinfos_lid.empty()) { + break; + } + + assert(sinfos_mla.size() == sinfos_lid.size()); + + return std::make_unique<llama_kv_cache_dsa_context>( + this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches)); + } while (false); + + return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_full() { + return std::make_unique<llama_kv_cache_dsa_context>(this); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize); +} + +bool llama_kv_cache_dsa::get_can_shift() const { + return kv_mla->get_can_shift() && + kv_lid->get_can_shift() && + kv_mla->get_size() == kv_lid->get_size(); +} + +void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + kv_mla->state_write(io, seq_id, flags); + kv_lid->state_write(io, seq_id, flags); +} + +void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + kv_mla->state_read(io, seq_id, flags); + kv_lid->state_read(io, seq_id, flags); +} + +llama_kv_cache * llama_kv_cache_dsa::get_mla() const { + return kv_mla.get(); +} + +llama_kv_cache * llama_kv_cache_dsa::get_lid() const { + return kv_lid.get(); +} + +// +// llama_kv_cache_dsa_context +// + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv) : + ctx_mla(kv->get_mla()->init_full()), + ctx_lid(kv->get_lid()->init_full()), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize) : + ctx_mla(kv->get_mla()->init_update(lctx, optimize)), + ctx_lid(kv->get_lid()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_mla, + slot_info_vec_t sinfos_lid, + std::vector<llama_ubatch> ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)), + ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default; + +bool llama_kv_cache_dsa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_mla->next(); + ctx_lid->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_dsa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_mla->apply(); + res = res & ctx_lid->apply(); + + return res; +} + +llama_memory_status llama_kv_cache_dsa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_mla.get()); +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_lid.get()); +} diff --git a/examples/talk-llama/llama-kv-cache-dsa.h b/examples/talk-llama/llama-kv-cache-dsa.h new file mode 100644 index 00000000000..e2b330993b8 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.h @@ -0,0 +1,138 @@ +#pragma once + +#include "llama-kv-cache.h" + +#include <vector> + +// +// llama_kv_cache_dsa +// + +// utilizes two instances of llama_kv_cache: +// - the first instance is for caching key tensors of the model, +// - the second instance is for caching lightning indexer key tensors + +class llama_kv_cache_dsa : public llama_memory_i { +public: + llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache_dsa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_kv_cache_dsa specific API + // + + llama_kv_cache * get_mla() const; + llama_kv_cache * get_lid() const; + +private: + // we keep indexer KV cache hparams instance here as llama_kv_cache stores only reference to it + llama_hparams hparams_lid; + const uint32_t n_stream = 1; + + std::unique_ptr<llama_kv_cache> kv_mla; + std::unique_ptr<llama_kv_cache> kv_lid; +}; + +class llama_kv_cache_dsa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // used for errors + llama_kv_cache_dsa_context(llama_memory_status status); + + // used to create a full-cache context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv); + + // used to create an update context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize); + + // used to create a batch processing context from a batch + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_ik, + std::vector<llama_ubatch> ubatches); + + virtual ~llama_kv_cache_dsa_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_dsa_context specific API + // + + const llama_kv_cache_context * get_mla() const; + const llama_kv_cache_context * get_lid() const; + +private: + //llama_kv_cache_dsa * kv; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector<llama_ubatch> ubatches; + + const llama_memory_context_ptr ctx_mla; + const llama_memory_context_ptr ctx_lid; + + const llama_memory_status status; +}; diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 26e2cb4270b..aa1b1b72ebe 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : hparams(model.hparams), unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { @@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + llama_memory_t mem_other_base = nullptr; + if (mem_other) { + mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base(); + } + + llama_memory_t mem_other_swa = nullptr; + if (mem_other) { + mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa(); + } + kv_base = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share); } void llama_kv_cache_iswa::clear(bool data) { diff --git a/examples/talk-llama/llama-kv-cache-iswa.h b/examples/talk-llama/llama-kv-cache-iswa.h index 70ab22f0d60..dfafc1ef510 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.h +++ b/examples/talk-llama/llama-kv-cache-iswa.h @@ -25,8 +25,10 @@ class llama_kv_cache_iswa : public llama_memory_i { uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache_iswa() = default; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index a49a055a630..2802103bdd8 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -79,6 +79,7 @@ static ggml_tensor * ggml_mul_mat_aux( llama_kv_cache::llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -89,14 +90,30 @@ llama_kv_cache::llama_kv_cache( uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : - model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : + model(model), hparams(hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type), + other(static_cast<llama_kv_cache *>(mem_other)), + v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()), + v_cells(*v_cells_impl) { + + // shared cells view the source cache's K/V tensors, so the cell count + // follows the source allocation: a fitted target can be smaller than the + // draft default and oversized views would overflow the source tensors + if (other) { + const uint32_t size_other = other->get_size(); + if (kv_size != size_other) { + LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other); + kv_size = size_other; + } + } GGML_ASSERT(kv_size % n_pad == 0); - const uint32_t n_layer_kv = hparams.n_layer_kv(); + const uint32_t n_layer = hparams.n_layer_all; // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { @@ -111,7 +128,7 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -159,7 +176,7 @@ llama_kv_cache::llama_kv_cache( const bool is_mla = hparams.is_mla(); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); continue; @@ -170,6 +187,24 @@ llama_kv_cache::llama_kv_cache( continue; } + if (share && other) { + const int32_t il_share = share(il); + + if (il_share >= 0) { + const auto & layer_share = other->layers[other->map_layer_ids[il_share]]; + + LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share, + layer_share.k->data, layer_share.v->data); + + map_layer_ids[il] = layers.size(); + + layers.push_back(layer_share); + layers.back().il = il; + + continue; + } + } + if (n_embd_head_k_all == 0) { n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { @@ -229,7 +264,7 @@ llama_kv_cache::llama_kv_cache( if (reuse) { LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer; il++) { const int32_t il_reuse = reuse(il); if (il_reuse < 0) { @@ -253,7 +288,7 @@ llama_kv_cache::llama_kv_cache( // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto & [buft, ctx] : ctx_map) { ggml_backend_buffer_t buf; - if (model.hparams.no_alloc) { + if (hparams.no_alloc) { buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) { t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it @@ -281,23 +316,37 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } - const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); - const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; - if (attn_rot_disable) { - LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); - } + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + n_embd_head_k_all = other->n_embd_head_k_all; + n_embd_head_v_all = other->n_embd_head_v_all; + + attn_rot_k = other->attn_rot_k; + attn_rot_v = other->attn_rot_v; + } else { + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; - attn_rot_k = - !attn_rot_disable && - n_embd_head_k_all > 0 && - ggml_is_quantized(type_k) && - hparams.n_embd_head_k() % 64 == 0; + // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer + if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { + attn_rot_k = true; + } - attn_rot_v = - !attn_rot_disable && - n_embd_head_v_all > 0 && - ggml_is_quantized(type_v) && - hparams.n_embd_head_v() % 64 == 0; + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; + } LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); @@ -341,6 +390,11 @@ void llama_kv_cache::clear(bool data) { } bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { @@ -404,6 +458,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { } void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -491,6 +550,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll } void llama_kv_cache::seq_keep(llama_seq_id seq_id) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -513,6 +577,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1"); @@ -558,6 +627,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll } void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1"); @@ -592,6 +666,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in } llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_min(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -600,6 +679,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { } llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_max(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -740,6 +824,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ } bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + bool updated = false; auto * sched = lctx->get_sched(); @@ -1015,6 +1104,11 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -1430,8 +1524,8 @@ struct args_set_input_kq_mask { int64_t n_tps; }; -template<bool causal, bool swa, bool is_2d, bool alibi> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa, bool is_2d, bool alibi> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { //const auto & hparams = args.hparams; const auto & ubatch = args.ubatch; @@ -1445,6 +1539,9 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * const int64_t n_stream = args.n_stream; const int64_t n_tps = args.n_tps; + const T mask_keep = llama_cast<T>(0.0f); + const T mask_drop = llama_cast<T>(-INFINITY); + // the min position in the batch for each sequence llama_pos seq_pos_min[LLAMA_MAX_SEQ]; std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); @@ -1563,46 +1660,55 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * } if (alibi) { - data[idst + j] = -std::abs(p0 - p1); + data[idst + j] = llama_cast<T>(static_cast<float>(-std::abs(p0 - p1))); } else { - data[idst + j] = 0.0f; + data[idst + j] = mask_keep; } continue; skip: - data[idst + j] = -INFINITY; + data[idst + j] = mask_drop; } } } } -template<bool causal, bool swa, bool is_2d> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa, bool is_2d> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool alibi = args.hparams.use_alibi; if (alibi) { - set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data); + set_input_kq_mask_impl<T, causal, swa, is_2d, true> (args, data); } else { - set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data); + set_input_kq_mask_impl<T, causal, swa, is_2d, false>(args, data); } } -template<bool causal, bool swa> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool is_2d = args.ubatch->is_pos_2d(); if (is_2d) { - set_input_kq_mask_impl<causal, swa, true> (args, data); + set_input_kq_mask_impl<T, causal, swa, true> (args, data); } else { - set_input_kq_mask_impl<causal, swa, false>(args, data); + set_input_kq_mask_impl<T, causal, swa, false>(args, data); } } -template<bool causal> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; if (swa) { - set_input_kq_mask_impl<causal, true> (args, data); + set_input_kq_mask_impl<T, causal, true> (args, data); + } else { + set_input_kq_mask_impl<T, causal, false>(args, data); + } +} + +template<typename T> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data, bool causal_attn) { + if (causal_attn) { + set_input_kq_mask_impl<T, true> (args, data); } else { - set_input_kq_mask_impl<causal, false>(args, data); + set_input_kq_mask_impl<T, false>(args, data); } } @@ -1610,7 +1716,6 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; const int64_t n_kv = dst->ne[0]; const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch @@ -1634,10 +1739,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u /*.n_tps =*/ n_tps, }; - if (causal_attn) { - set_input_kq_mask_impl<true> (args, data); + if (dst->type == GGML_TYPE_F16) { + set_input_kq_mask_impl<ggml_fp16_t>(args, (ggml_fp16_t *) dst->data, causal_attn); } else { - set_input_kq_mask_impl<false>(args, data); + set_input_kq_mask_impl<float>(args, (float *) dst->data, causal_attn); } //const int64_t t_end = ggml_time_us(); @@ -1798,6 +1903,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + GGML_ASSERT(!other); + auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); @@ -1843,6 +1951,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co } void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); io.write(&n_stream, sizeof(n_stream)); @@ -1859,7 +1972,19 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla uint32_t cell_range_begin = cells.size(); for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + bool add_cell = true; + + add_cell = add_cell && !cells.is_empty(i); + add_cell = add_cell && (seq_id == -1 || cells.seq_has(i, seq_id)); + + // check the cell is not SWA-masked + if (add_cell && seq_id != -1) { + const bool is_masked = llama_hparams::is_masked_swa(n_swa, swa_type, cells.pos_get(i), cells.seq_pos_max(seq_id)); + + add_cell = !is_masked; + } + + if (add_cell) { ++cell_count; if (cell_range_begin == cells.size()) { cell_range_begin = i; @@ -1896,6 +2021,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla } void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); @@ -2112,7 +2242,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 sinfo = find_slot(ubatch, false); if (sinfo.empty()) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to find %d available cells in kv cache\n", __func__, cell_count); return false; } diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 0b62dc7b232..3d68f98c142 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -93,8 +93,12 @@ class llama_kv_cache : public llama_memory_i { using slot_info_vec_t = std::vector<slot_info>; + // TODO: refactor the memory instances to not depend on `llama_model` + // instead pass all necessary info (e.g. hparams, dev layers, arch, etc.) directly + // likely through `struct llama_memory_params` llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -105,8 +109,10 @@ class llama_kv_cache : public llama_memory_i { uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache() = default; @@ -260,7 +266,12 @@ class llama_kv_cache : public llama_memory_i { // note: this is not part of the KV state and it's only used to speed-up the find_slot() method std::vector<uint32_t> v_heads; - std::vector<llama_kv_cells> v_cells; + // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] + llama_kv_cache * other; + + std::shared_ptr<llama_kv_cells_vec> v_cells_impl; + + llama_kv_cells_vec & v_cells; // maps from a sequence id to a stream id std::vector<uint32_t> seq_to_stream; diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h index 10063bf4272..fddd31a0b21 100644 --- a/examples/talk-llama/llama-kv-cells.h +++ b/examples/talk-llama/llama-kv-cells.h @@ -531,3 +531,5 @@ class llama_kv_cells { } } }; + +using llama_kv_cells_vec = std::vector<llama_kv_cells>; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 72f5c2fea72..c7d4bcd413e 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_ubatch, n_pad, + nullptr, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( @@ -57,7 +59,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 33b3b395e0c..f2d49cbce54 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -33,6 +33,7 @@ llama_memory_hybrid::llama_memory_hybrid( hparams(model.hparams), mem_attn(new llama_kv_cache( model, + model.hparams, type_k, type_v, v_trans, @@ -43,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid( n_pad, n_swa, swa_type, + nullptr, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( @@ -57,7 +60,7 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index ec5dc5835dd..6a4892fb471 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -26,7 +26,7 @@ llama_memory_recurrent::llama_memory_recurrent( uint32_t n_seq_max, uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { - const int32_t n_layer = hparams.n_layer; + const int32_t n_layer = hparams.n_layer(); head = 0; size = mem_size; @@ -863,7 +863,7 @@ void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std:: void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const { const uint32_t s_trans = 0; - const uint32_t n_layer = hparams.n_layer; + const uint32_t n_layer = hparams.n_layer(); io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); @@ -1047,8 +1047,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell io.read(&s_trans, sizeof(s_trans)); io.read(&n_layer, sizeof(n_layer)); - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + if (n_layer != hparams.n_layer()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer()); return false; } if (cell_count > size) { diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4ad1612e45b..db825396645 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -23,6 +23,8 @@ struct llama_memory_params { bool swa_full; llama_context_type ctx_type; + + llama_memory_t mem_other; }; enum llama_memory_status { @@ -76,6 +78,8 @@ struct llama_memory_i { // return negative value to indicate that the layer il should not reuse memory using layer_reuse_cb = std::function<int32_t(int32_t il)>; + using layer_share_cb = std::function<int32_t(int32_t il)>; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index c645d0785ab..0d1cf3cc33b 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -146,7 +146,7 @@ namespace GGUFMeta { const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { arr_type, - size_t(gguf_get_arr_n(ctx, k)), + gguf_get_arr_n(ctx, k), arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } @@ -393,6 +393,7 @@ namespace GGUFMeta { } template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required); + template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required); template<typename T> bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { @@ -445,7 +446,7 @@ namespace GGUFMeta { } if (n > N_MAX) { - throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", n, (uint32_t) N_MAX, key.c_str())); } if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) { @@ -502,9 +503,9 @@ namespace GGUFMeta { } // TODO: this is not very clever - figure out something better - template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<int, 4>> (enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required); - template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( @@ -1050,10 +1051,10 @@ struct ggml_tensor * llama_model_loader::create_tensor( if (it == ctx_map.end()) { // one ggml context per buffer type int max_n_tensors = n_tensors; - max_n_tensors += 1; // duplicated output tensor - max_n_tensors += hparams.n_layer*2; // duplicated rope freq tensors + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += hparams.n_layer()*2; // duplicated rope freq tensors if (files.empty()) { - max_n_tensors += hparams.n_layer*256; // this should be well above what any model actually uses + max_n_tensors += hparams.n_layer()*256; // this should be well above what any model actually uses } const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 528e4c9c069..67d4a9df0f0 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -14,9 +14,6 @@ bool llama_model_saver_supports_arch(llm_arch arch) { switch (arch) { - case LLM_ARCH_QWEN3NEXT: - case LLM_ARCH_QWEN35: - case LLM_ARCH_QWEN35MOE: case LLM_ARCH_PLAMO3: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: @@ -29,6 +26,7 @@ bool llama_model_saver_supports_arch(llm_arch arch) { case LLM_ARCH_APERTUS: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_MELLUM: return false; default: return true; @@ -79,7 +77,7 @@ void llama_model_saver::add_kv(const enum llm_kv key, const char value) { template <typename Container> void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { GGML_ASSERT(model != nullptr || !per_layer); - const size_t n_values = per_layer ? size_t(model->hparams.n_layer) : value.size(); + const size_t n_values = per_layer ? size_t(model->hparams.n_layer()) : value.size(); GGML_ASSERT(n_values <= value.size()); if (n_values == 0) { @@ -106,6 +104,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); } else if (std::is_same<typename Container::value_type, uint32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same<typename Container::value_type, bool>::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_BOOL, value.data(), n_values); } else if (std::is_same<typename Container::value_type, int32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); } else if (std::is_same<typename Container::value_type, float>::value) { @@ -206,7 +206,7 @@ void llama_model_saver::add_kv_from_model() { if (hparams.n_embd_out_impl > 0) { add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); } - add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -227,8 +227,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); add_kv(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); add_kv(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers); - add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers); + add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn); add_kv(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers); + add_kv(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr); add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); @@ -244,7 +245,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); // saved as LLM_KV_ATTENTION_RECURRENT_LAYERS instead add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); @@ -278,6 +279,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + add_kv(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, true); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 0c3e03a61dc..4f12e0949ac 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -10,6 +10,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -80,6 +81,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_mpt(params); case LLM_ARCH_STABLELM: return new llama_model_stablelm(params); + case LLM_ARCH_MELLUM: + return new llama_model_mellum(params); case LLM_ARCH_QWEN: return new llama_model_qwen(params); case LLM_ARCH_QWEN2: @@ -136,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_gemma3n(params); case LLM_ARCH_GEMMA4: return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA4_ASSISTANT: + return new llama_model_gemma4_assistant(params); case LLM_ARCH_GEMMA_EMBEDDING: return new llama_model_gemma_embedding(params); case LLM_ARCH_STARCODER2: @@ -172,6 +177,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_deepseek2(params); case LLM_ARCH_DEEPSEEK2OCR: return new llama_model_deepseek2ocr(params); + case LLM_ARCH_DEEPSEEK32: + return new llama_model_deepseek32(params); case LLM_ARCH_GLM_DSA: return new llama_model_glm_dsa(params); case LLM_ARCH_MISTRAL4: @@ -368,10 +375,10 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // count only the same type of previous layers to avoid this auto get_il_eff = [&](const size_t il){ size_t ret = 0; - const bool il_is_recurrent = hparams.is_recurrent(il); - const bool il_is_swa = hparams.is_swa(il); + const bool il_is_recr = hparams.is_recr(il); + const bool il_is_swa = hparams.is_swa(il); for (size_t il_prev = 0; il_prev < il; il_prev++) { - ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + ret += hparams.is_recr(il_prev) == il_is_recr && hparams.is_swa(il_prev) == il_is_swa; } return ret; }; @@ -393,7 +400,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str rotation = get_il_eff(il) % ud->n_devices; } else { il = 0; - rotation = hparams.n_layer % ud->n_devices; + rotation = hparams.n_layer() % ud->n_devices; } const ggml_tensor * tensor_axis_0 = suffix.empty() ? tensor : ud->model->get_tensor((prefix + suffix).c_str()); if (tensor_axis_0 == nullptr) { @@ -407,16 +414,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str auto get_tensor_config = [&]() -> tensor_config { // standard attention if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_q_bias) || std::regex_match(tensor_name, pattern_kv_bias)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_qkv_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if ( std::regex_match(tensor_name, pattern_qkv_bias)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_qk_norm)) { return get_tensor_config_impl(tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); @@ -432,7 +439,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str } if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); @@ -485,7 +492,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); }; - auto get_split_segments = [&](int axis, uint32_t il) -> std::vector<int64_t> { + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector<std::pair<int64_t, uint32_t>> { if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { const int64_t head_k_dim = hparams.ssm_d_state; const int64_t head_v_dim = hparams.ssm_d_state; @@ -500,26 +507,26 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return {key_dim, key_dim, value_dim}; + return {{key_dim, 2}, {value_dim, 1}}; } } else { const int64_t head_ratio = n_v_heads / n_k_heads; if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return std::vector<int64_t>(2 + head_ratio, key_dim); + return {{key_dim, 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { - return std::vector<int64_t>(head_ratio, key_dim); + return {{key_dim, head_ratio}}; } if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { - return std::vector<int64_t>(head_ratio, n_k_heads); + return {{n_k_heads, head_ratio}}; } if (std::regex_match(tensor_name, pattern_r_cache)) { - return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1)); + return {{key_dim * (hparams.ssm_d_conv - 1), 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_s_cache)) { - return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim); + return {{n_k_heads * head_v_dim * head_v_dim, head_ratio}}; } } @@ -527,9 +534,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { @@ -537,21 +544,23 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); - return {n_embd, n_embd_gqa, n_embd_gqa}; + return {{n_embd, 1}, {n_embd_gqa, 2}}; } if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; }; - auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<int64_t> & segments) -> std::vector<int64_t> { - if (hparams.is_recurrent(il)) { + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t, uint32_t>> & segments) -> std::vector<int64_t> { + // for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used + if (hparams.is_recr(il)) { // linear attention - const int64_t head_dim = hparams.ssm_d_state; - const int64_t granularity_qkv = std::lcm(blck_size, head_dim); + const int64_t head_dim = hparams.ssm_d_state; + const int64_t blck_size_perf = std::lcm(blck_size, 128); + const int64_t granularity_qkv = std::lcm(blck_size_perf, head_dim); if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { return std::vector<int64_t>(segments.size(), granularity_qkv); @@ -573,17 +582,24 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // regular attention const uint32_t n_gqa = hparams.n_gqa(il); const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); + + // to handle head sizes like 80, only increase granularity while it doesn't cause underutilization + int64_t blck_size_perf = blck_size; + while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) { + blck_size_perf *= 2; + } + if (std::regex_match(tensor_name, pattern_attn_sinks)) { GGML_ASSERT(segments.size() == 1); - return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa}; + return {std::lcm(n_embd_q, blck_size_perf)/n_embd_q * n_gqa}; } - const int64_t granularity_q = std::lcm(n_embd_q, blck_size); + const int64_t granularity_q = std::lcm(n_embd_q, blck_size_perf); if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { GGML_ASSERT(segments.size() == 1); // some models have Q gate tensors, for those cases the granularity needs to be doubled: if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { - return {std::lcm(2*n_embd_q, blck_size)}; + return {std::lcm(2*n_embd_q, blck_size_perf)}; } return {granularity_q}; } @@ -600,16 +616,17 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return {granularity_kv}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { - GGML_ASSERT(segments.size() == 3); - return {granularity_q, granularity_kv, granularity_kv}; + GGML_ASSERT(segments.size() == 2); + return {granularity_q, granularity_kv}; } } // FFN if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { - GGML_ASSERT(segments.size() <= 2); - return std::vector<int64_t>(segments.size(), blck_size); + const int64_t blck_size_perf = std::lcm(blck_size, 128); + GGML_ASSERT(segments.size() == 1); + return {blck_size_perf}; } // everything else @@ -622,7 +639,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_config tc = get_tensor_config(); split_state.axis = tc.axis; if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { - const int64_t ne_full = tensor->ne[split_state.axis]; const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); const float * tensor_split = ud->model->tensor_split(); std::vector<float> tensor_split_scan; @@ -633,12 +649,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_split_scan[j] += tensor_split_scan[j - 1]; } } - const std::vector<int64_t> segments = get_split_segments(split_state.axis, tc.il); + const std::vector<std::pair<int64_t, uint32_t>> segments = get_split_segments(split_state.axis, tc.il); const std::vector<int64_t> granularity = get_split_granularity(blck_size, tc.il, segments); for (size_t is = 0; is < segments.size(); is++) { - const int64_t ne_s = segments[is]; - const int64_t g_s = granularity[is]; - GGML_ASSERT(ne_full % g_s == 0); + const int64_t ne_s = segments[is].first; + const uint32_t nr_s = segments[is].second; + const int64_t g_s = granularity[is]; int64_t low = 0; size_t j = 0; for (; j < ud->n_devices - 1; j++) { @@ -651,10 +667,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str low = high; } split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + split_state.nr[is] = nr_s; } split_state.n_segments = segments.size(); } else { memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.nr[0] = 1; split_state.n_segments = 1; } return split_state; @@ -758,6 +776,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B"; + case LLM_TYPE_12B_A2_5B: return "12B.A2.5B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; @@ -779,6 +798,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_397B_A17B: return "397B.A17B"; + case LLM_TYPE_685B_A37B: return "685B.A37B"; case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -815,6 +835,28 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } +// Maps the GGUF `<arch>.hidden_activation` string to the FFN op type used by the +// graph builders. Only gated activations that map cleanly to llm_ffn_op_type are +// listed; unrecognized values fall back to GeGLU, which matches the historical +// default for ModernBert-style architectures. +static const std::map<std::string, llm_ffn_op_type> LLM_FFN_OP_TYPES_FROM_STRING = { + { "gelu", LLM_FFN_GEGLU }, + { "geglu", LLM_FFN_GEGLU }, + { "silu", LLM_FFN_SWIGLU }, + { "swish", LLM_FFN_SWIGLU }, + { "swiglu", LLM_FFN_SWIGLU }, + { "relu", LLM_FFN_RELU }, + { "reglu", LLM_FFN_REGLU }, +}; + +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback) { + const auto it = LLM_FFN_OP_TYPES_FROM_STRING.find(name); + if (it != LLM_FFN_OP_TYPES_FROM_STRING.end()) { + return it->second; + } + return fallback; +} + // CPU: ACCEL -> GPU host -> CPU extra -> CPU static buft_list_t make_cpu_buft_list(const std::vector<llama_device> & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; @@ -1002,7 +1044,7 @@ void llama_model_base::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); @@ -1044,28 +1086,29 @@ void llama_model_base::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); - std::fill( - hparams.recurrent_layer_arr.begin(), - hparams.recurrent_layer_arr.end(), - llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.is_swa_impl.begin(), hparams.is_swa_impl.end(), 0); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), llm_arch_is_recurrent(ml.get_arch()) ? 1 : 0); std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); - std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); - std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer(), false); + + // Populate deepstack_mapping_arr - initialized to -1 (no deepstack) + std::fill(hparams.deepstack_mapping_arr.begin(), hparams.deepstack_mapping_arr.end(), -1); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer(), false); bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -1164,7 +1207,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { const auto & use_mlock = params.use_mlock; const auto & tensor_split = params.tensor_split; - const int n_layer = hparams.n_layer; + const int n_layer_all = hparams.n_layer_all; const int n_gpu_layers = this->n_gpu_layers(); const bool use_mmap_buffer = true; @@ -1221,10 +1264,10 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { splits[i] /= split_sum; } - const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); + const int i_gpu_start = std::max(n_layer_all + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, n_layer_all + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); + const bool is_swa = il < n_layer_all && hparams.is_swa(il); if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); return {cpu_dev, &pimpl->cpu_buft_list}; @@ -1240,13 +1283,13 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; // assign the repeating layers to the devices according to the splits - pimpl->dev_layer.resize(n_layer); - for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer.resize(n_layer_all); + for (int il = 0; il < n_layer_all; ++il) { pimpl->dev_layer[il] = get_layer_buft_list(il); } // assign the output layer - pimpl->dev_output = get_layer_buft_list(n_layer); + pimpl->dev_output = get_layer_buft_list(n_layer_all); const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; @@ -1262,14 +1305,14 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { throw std::runtime_error("model has expert layers but no expert layers are used"); } - layers.resize(n_layer); + layers.resize(n_layer_all); // call the per-model loading function load_arch_tensors(ml); // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) // this avoids having to add scale loading to every architecture - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { auto & layer = layers[i]; // attention weight scales (per-tensor, shape {1}) @@ -1527,7 +1570,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } if (llama_supports_gpu_offload()) { - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + const int n_gpu = std::min(n_gpu_layers, n_layer_all); int n_repeating = n_gpu; if (n_repeating > 0) { @@ -1536,8 +1579,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; + const int max_backend_supported_layers = n_layer_all + 1; + const int max_offloadable_layers = n_layer_all + 1; LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); } @@ -1606,7 +1649,8 @@ const float * llama_model::tensor_split() const { } uint32_t llama_model::n_gpu_layers() const { - return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; + // note: plus 1 for the "output" layer + return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer_all + 1; } llama_split_mode llama_model::split_mode() const { @@ -1639,10 +1683,10 @@ uint64_t llama_model::n_elements() const { void llama_model::print_info() const { const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) { + auto print_f = [](const std::function<int32_t(uint32_t)> & f, uint32_t n) { bool is_var = false; - std::vector<uint32_t> v; + std::vector<int32_t> v; for (uint32_t i = 0; i < n; ++i) { v.push_back(f(i)); if (v[i] != v[0]) { @@ -1675,19 +1719,21 @@ void llama_model::print_info() const { if (!hparams.vocab_only) { LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer()); + LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); @@ -1695,7 +1741,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); @@ -1716,6 +1762,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + if (arch == LLM_ARCH_GRANITE && + std::any_of(hparams.deepstack_mapping_arr.begin(), + hparams.deepstack_mapping_arr.end(), + [](const auto & entry) { return entry >= 0; })) { + LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__, + print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; }, + hparams.n_layer_all).c_str()); + } // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); @@ -1769,7 +1823,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_DEEPSEEK32 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -1787,7 +1841,11 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + if (arch == LLM_ARCH_MELLUM || + arch == LLM_ARCH_QWEN3MOE || + arch == LLM_ARCH_OPENAI_MOE || + arch == LLM_ARCH_QWEN3VLMOE || + arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -1818,7 +1876,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + LLAMA_LOG_INFO("%s: n_layer_nextn = %d\n", __func__, hparams.n_layer_nextn); } if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { @@ -1957,6 +2015,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; + case LLM_ARCH_DEEPSEEK32: + { + res = new llama_kv_cache_dsa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + nullptr); + } break; // Models that need standard caching should rely on recurrent/hybrid // checks default: @@ -1983,22 +2058,21 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; if (arch == LLM_ARCH_FALCON_H1) { - filter_attn = [&](int32_t) { return true; }; - filter_recr = [&](int32_t) { return true; }; + filter_attn = [&](uint32_t) { return true; }; + filter_recr = [&](uint32_t) { return true; }; } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - filter_attn = [&](int32_t il) { - return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + filter_attn = [&](uint32_t il) { + return !hparams.is_recr(il) && hparams.n_ff(il) == 0; }; - filter_recr = [&](int32_t il) { - return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + filter_recr = [&](uint32_t il) { + return hparams.is_recr(il) && hparams.n_ff(il) == 0; }; } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; - filter_attn = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && !hparams.is_recurrent(il); + filter_attn = [&](uint32_t il) { + return il < hparams.n_layer() && !hparams.is_recr(il); }; - filter_recr = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && hparams.is_recurrent(il); + filter_recr = [&](uint32_t il) { + return il < hparams.n_layer() && hparams.is_recr(il); }; } @@ -2043,13 +2117,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_recr */ std::move(filter_recr)); } } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; llama_kv_cache::layer_filter_cb filter = nullptr; + llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_share_cb share = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { - reuse = [&](int32_t il) { - if (il >= (int32_t) hparams.n_layer_kv_from_start) { - return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + reuse = [&](uint32_t il) { + GGML_ASSERT(hparams.n_layer_kv_from_start >= 2); + + if (il >= (uint32_t)hparams.n_layer_kv_from_start) { + return hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); } return -1; @@ -2057,32 +2134,73 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } if (mtp_on_hybrid_qwen35) { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; - filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } + + if (arch == LLM_ARCH_STEP35 && hparams.n_layer_nextn > 0) { + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP) { + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } else { + filter = [&](uint32_t il) { return il < hparams.n_layer(); }; + } } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - filter, - reuse); + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + llama_memory_t mem_other = llama_get_memory(cparams.ctx_other); + + share = [&](int32_t il) { + const llama_model * model_other = llama_get_model(cparams.ctx_other); + + if (hparams.is_swa(il)) { + return llama_model_n_layer(model_other) - 2; + } + + return llama_model_n_layer(model_other) - 1; + }; + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + mem_other, + filter, + reuse, + share); + } else { + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + nullptr, + filter, + reuse, + share); + } } else { GGML_ASSERT(!hparams.is_swa_any()); res = new llama_kv_cache( *this, + hparams, params.type_k, params.type_v, !cparams.flash_attn, @@ -2093,7 +2211,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, + nullptr, filter, + nullptr, nullptr); } } @@ -2181,7 +2301,7 @@ int32_t llama_model_n_embd_out(const llama_model * model) { } int32_t llama_model_n_layer(const llama_model * model) { - return model->hparams.n_layer; + return model->hparams.n_layer(); } int32_t llama_model_n_head(const llama_model * model) { @@ -2272,6 +2392,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2OCR: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: @@ -2325,6 +2446,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -2356,6 +2478,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: case LLM_ARCH_TALKIE: + case LLM_ARCH_MELLUM: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index b797b8966ac..992c8d9c8fd 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_7B_A1B, LLM_TYPE_8B_A1B, // lfm2moe + LLM_TYPE_12B_A2_5B, LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe @@ -137,6 +138,7 @@ enum llm_type { LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_397B_A17B, // Qwen3.5 + LLM_TYPE_685B_A37B, // DeepSeek V3.2 LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, @@ -144,6 +146,10 @@ enum llm_type { std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); +// Map a GGUF activation-name string to llm_ffn_op_type. Returns `fallback` if +// the string is empty or not recognized. +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; @@ -542,6 +548,10 @@ struct llama_model { struct ggml_tensor * output_s = nullptr; struct ggml_tensor * output_in_s = nullptr; + // NextN/MTP model-level projections + struct ggml_tensor * nextn_proj_pre = nullptr; + struct ggml_tensor * nextn_proj_post = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; @@ -694,7 +704,9 @@ const char * llm_type_name(llm_type type); // convenience macro for loading local variables for load_tensors() in llama_model_base // note: cast to int64_t since we will use these for the tensor dimensions #define LLAMA_LOAD_LOCALS \ - const int n_layer = hparams.n_layer; GGML_UNUSED(n_layer); \ + const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \ + const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \ + const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \ const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 43e05c3d56f..cf92ce4bb8b 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -847,7 +847,7 @@ static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<t qs.has_tied_embeddings = false; } } - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer(); } // @@ -1348,7 +1348,7 @@ llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * des model->hparams.n_embd = desc->n_embd; model->hparams.n_embd_head_k_full = desc->n_embd_head_k; model->hparams.n_embd_head_v_full = desc->n_embd_head_v; - model->hparams.n_layer = desc->n_layer; + model->hparams.n_layer_all = desc->n_layer; model->hparams.n_expert = desc->n_expert; for (uint32_t i = 0; i < desc->n_layer; i++) { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 473becade82..9a4bed49487 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -353,6 +353,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: case LLAMA_VOCAB_PRE_TYPE_MINERVA: + case LLAMA_VOCAB_PRE_TYPE_MELLUM2: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -432,6 +433,15 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI: + // Same lookaheads as GPT4O but with \p{M} added so combining marks + // (diacritics) attach to their base letters. Avoids excessive + // backtracking on scripts that use them heavily (Bengali, Hindi, + // Telugu, Thai, ...). See PR #22716 for benchmarks. + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))*((?=[\\p{L}\\p{M}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))+((?=[\\p{L}\\p{M}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: regex_exprs = { // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" @@ -519,6 +529,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_WHITESPACE: + // whitespace pre-tokenizer (jinaai/jina-embeddings-v2-base-zh) + regex_exprs = { + "\\S+", + }; + byte_encode = false; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -747,7 +764,7 @@ struct llm_tokenizer_wpm_session { void tokenize(const std::string & text, std::vector<llama_token> & output) { // normalize and split by whitespace - std::vector<std::string> words = preprocess(text); + std::vector<std::string> words = preprocess(text, vocab.get_normalizer_lowercase()); // bos token prepended already // find the longest tokens that form the words @@ -792,7 +809,7 @@ struct llm_tokenizer_wpm_session { } // TODO: reduce string copies by using cpts_offs array - static std::vector<std::string> preprocess(const std::string & text) { + static std::vector<std::string> preprocess(const std::string & text, bool lowercase) { const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector<std::string> words(1, ""); @@ -811,7 +828,7 @@ struct llm_tokenizer_wpm_session { continue; } - const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + const std::string s = unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); @@ -1671,6 +1688,35 @@ struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { const llama_vocab & vocab; }; +struct llm_tokenizer_whitespace_session : llm_tokenizer_bpe_session { + llm_tokenizer_whitespace_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector<llama_token> & output) override { + const bool lowercase = vocab.get_normalizer_lowercase(); + + std::string segment; + auto flush = [&]() { + if (!segment.empty()) { + llm_tokenizer_bpe_session::tokenize(segment, output); + segment.clear(); + } + }; + + for (uint32_t cpt : unicode_cpts_from_utf8(text)) { + // drop whitespace + if (unicode_cpt_flags_from_cpt(cpt).is_whitespace) { + flush(); + } else { + segment += unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); + } + } + flush(); + } + +private: + const llama_vocab & vocab; +}; + // // impl // @@ -1751,6 +1797,7 @@ struct llama_vocab::impl { bool remove_extra_whitespaces = false; bool escape_whitespaces = true; bool treat_whitespace_as_suffix = false; + bool normalizer_lowercase = true; // Lowercase normalizer (tokenizer.json) std::unordered_map<std::string, llama_token> token_to_id; std::vector<token_data> id_to_token; @@ -1768,6 +1815,8 @@ struct llama_vocab::impl { // set of all tokens that cause "end of generation" std::set<llama_token> special_eog_ids; + std::vector<llama_token> suppress_tokens; + std::unique_ptr<llm_tokenizer> tokenizer; std::vector<char> precompiled_charsmap; @@ -1900,7 +1949,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna" || tokenizer_model == "whitespace") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2105,7 +2154,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( - tokenizer_pre == "gemma4") { + tokenizer_pre == "gemma4" || + tokenizer_pre == "granite-embed-multi-311m") { pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; escape_whitespaces = true; } else if ( @@ -2119,6 +2169,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "roberta-bpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; add_sep = true; + } else if ( + tokenizer_pre == "whitespace") { + pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE; + normalizer_lowercase = false; } else if ( tokenizer_pre == "refact") { pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; @@ -2211,6 +2265,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "talkie") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-embed-multi-97m") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "tiny_aya") { pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; @@ -2269,6 +2328,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "solar-open") { pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; clean_spaces = false; + } else if ( + tokenizer_pre == "mellum2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MELLUM2; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2470,6 +2532,19 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } } + // Lowercase normalizer flag (consulted by WPM / whitespace BPE) + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false); + + // suppress tokens + { + const int suppress_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SUPPRESS_TOKENS).c_str()); + if (suppress_idx != -1) { + const int n = gguf_get_arr_n(ctx, suppress_idx); + const int32_t * data = (const int32_t *) gguf_get_arr_data(ctx, suppress_idx); + suppress_tokens.assign(data, data + n); + } + } + // auto-detect special tokens by text // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... // for now, we apply this workaround to find the tokens based on their text @@ -3264,6 +3339,8 @@ std::vector<llama_token> llama_vocab::impl::tokenize( std::unique_ptr<llm_tokenizer_bpe_session> session; if (vocab.get_tokenizer_model() == "hybriddna") { session = std::make_unique<llm_tokenizer_hybriddna_session>(vocab, *tok_bpe); + } else if (vocab.get_tokenizer_model() == "whitespace") { + session = std::make_unique<llm_tokenizer_whitespace_session>(vocab, *tok_bpe); } else { session = std::make_unique<llm_tokenizer_bpe_session>(vocab, *tok_bpe); } @@ -3892,6 +3969,14 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const { return pimpl->treat_whitespace_as_suffix; } +bool llama_vocab::get_normalizer_lowercase() const { + return pimpl->normalizer_lowercase; +} + +const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const { + return pimpl->suppress_tokens; +} + int llama_vocab::max_token_len() const { return pimpl->max_token_len; } diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 8ab77594284..2626ae36e33 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -8,59 +8,62 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, - LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, - LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, - LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, - LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, - LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, - LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, - LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, - LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, - LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, - LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, - LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, - LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, - LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, + LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, + LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, + LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, + LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, + LLAMA_VOCAB_PRE_TYPE_WHITESPACE = 53, + LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI = 54, + LLAMA_VOCAB_PRE_TYPE_MELLUM2 = 55, }; struct LLM_KV; @@ -138,6 +141,9 @@ struct llama_vocab { bool get_remove_extra_whitespaces () const; bool get_escape_whitespaces () const; bool get_treat_whitespace_as_suffix() const; + bool get_normalizer_lowercase () const; + + const std::vector<llama_token> & get_suppress_tokens() const; int max_token_len() const; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index dfe30ce8f61..a67fa8039a4 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -225,7 +225,9 @@ static bool llama_prepare_model_devices(const llama_model_params & params, llama } case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back({false, dev}); + if (igpus.empty()) { + igpus.push_back({false, dev}); + } break; case GGML_BACKEND_DEVICE_TYPE_META: GGML_ABORT("fatal error"); @@ -239,8 +241,9 @@ static bool llama_prepare_model_devices(const llama_model_params & params, llama // add GPUs model->devices.insert(model->devices.end(), gpus.begin(), gpus.end()); - // add integrated GPUs only if no other devices were found - if (model->devices.empty()) { + // add integrated GPUs only if no discrete GPUs were found + // (RPC servers do not count, otherwise the local iGPU would be dropped on iGPU+RPC setups) + if (gpus.empty()) { model->devices.insert(model->devices.end(), igpus.begin(), igpus.end()); } } diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index e8374c53b70..27e48067428 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -339,6 +339,7 @@ extern "C" { uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] + uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch) int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -387,6 +388,10 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + + // a source/target/parent context + // can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts + struct llama_context * ctx_other; }; struct llama_model_tensor_override { @@ -975,7 +980,11 @@ extern "C" { // Set whether the model is in warmup mode or not // If true, all model tensors are activated during llama_decode() to load and cache their weights. - LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // + // note: using this can cause extra graph reallocations because it changes the graph topology with MoE models, + // so it is generally not recommended to use in practice. will be removed in the future + DEPRECATED(LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup), + "user code should do warmup runs manually [TAG_LLAMA_GRAPH_NO_WARMUP]"); // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index a7c77ee5d28..063b214256e 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -30,7 +30,7 @@ void llama_model_afmoe::load_arch_hparams(llama_model_loader & ml) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 56: type = LLM_TYPE_6B; break; case 32: type = LLM_TYPE_26B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index bec7136521c..6dfb8905fbe 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -2,12 +2,13 @@ void llama_model_apertus::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); - switch (hparams.n_layer) { + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer()); + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index d086c4717ff..9536e7c5d42 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -4,7 +4,7 @@ void llama_model_arcee::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); // Arcee uses the same structure as Llama - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index 27deadffeb7..09ee0f752f0 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -4,7 +4,7 @@ void llama_model_arctic::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 128) { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 35: type = LLM_TYPE_10B_128x3_66B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 9bd04127b25..b38b2064785 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -10,7 +10,7 @@ void llama_model_arwkv7::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: switch (hparams.n_embd) { case 768: type = LLM_TYPE_190M; break; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4d26081cd5d..585f3614174 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -2,7 +2,7 @@ void llama_model_baichuan::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index fe1ae10864b..7faf73c835b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -8,7 +8,7 @@ void llama_model_bailingmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_16B; break; case 88: type = LLM_TYPE_290B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 2f0d44a6259..5000e9c6db8 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -9,17 +9,13 @@ void llama_model_bailingmoe2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 20: type = LLM_TYPE_16B_A1B; break; - case 21: type = LLM_TYPE_16B_A1B; break; case 32: type = LLM_TYPE_100B_A6B; break; - case 33: type = LLM_TYPE_100B_A6B; break; default: type = LLM_TYPE_UNKNOWN; } } @@ -39,9 +35,9 @@ void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -78,7 +74,7 @@ void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); @@ -112,8 +108,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph ggml_tensor * inp_out_ids = build_inp_out_ids(); - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -146,7 +141,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 3c28f419ccf..53ce29f23ca 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -1,9 +1,9 @@ #include "models.h" void llama_model_bert::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 3: type = LLM_TYPE_17M; break; // bge-micro case 6: diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index 7e8125deec4..c8330274580 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -3,7 +3,7 @@ void llama_model_bitnet::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index 30b0f3d07d0..609d2ddf998 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -3,7 +3,7 @@ void llama_model_bloom::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 30: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 4bceaefd63b..4f45acecf84 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -6,7 +6,7 @@ void llama_model_chameleon::load_arch_hparams(llama_model_loader & ml) { hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_34B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 6766fa71c15..7ae5b938fde 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -2,7 +2,8 @@ void llama_model_chatglm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: { if (hparams.n_head(0) == 16) { type = LLM_TYPE_1_5B; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index 274dd3342a7..de53bb98184 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -2,7 +2,8 @@ void llama_model_codeshell::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 42: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 2e231bb3f93..750f57a394e 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -2,7 +2,8 @@ void llama_model_cogvlm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index a514cf88fc6..61a5945a194 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -5,6 +5,7 @@ void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { uint32_t swa_period = 4; ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -12,7 +13,8 @@ void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index adf7fcaa20f..94a46188bb8 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -3,7 +3,8 @@ void llama_model_command_r::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_35B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index af71c775365..4f5ac4d06a4 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -1,14 +1,14 @@ #include "models.h" void llama_model_dbrx::load_arch_hparams(llama_model_loader & ml) { -ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); -switch (hparams.n_layer) { - case 40: type = LLM_TYPE_16x12B; break; - default: type = LLM_TYPE_UNKNOWN; + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; + } } - } void llama_model_dbrx::load_arch_tensors(llama_model_loader &) { LLAMA_LOAD_LOCALS; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 567e3535276..cdfcf29e02f 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -2,7 +2,8 @@ void llama_model_deci::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; case 162: type = LLM_TYPE_405B; break; diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index 1fe54adc13e..a9e8bc51403 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -5,7 +5,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + const bool is_lite = (hparams.n_layer() == 27 || hparams.n_layer() == 26 || (hparams.n_layer() == 48 && n_vocab == 128256)); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); @@ -23,7 +23,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + if ((hparams.n_layer() == 47 || hparams.n_layer() == 48) && n_vocab == 154880) { // GLM 4.7 Lite hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } else { @@ -43,7 +43,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { hparams.f_attn_temp_offset = 0.0f; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 27: type = LLM_TYPE_16B; break; case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; @@ -191,8 +191,7 @@ llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_out_ids = build_inp_out_ids(); - int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < effective_n_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -366,7 +365,7 @@ llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_p Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } - if (il == effective_n_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/deepseek2ocr.cpp b/examples/talk-llama/models/deepseek2ocr.cpp index f9e4c98785c..65d31c31b93 100644 --- a/examples/talk-llama/models/deepseek2ocr.cpp +++ b/examples/talk-llama/models/deepseek2ocr.cpp @@ -14,7 +14,7 @@ void llama_model_deepseek2ocr::load_arch_hparams(llama_model_loader & ml) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/deepseek32.cpp b/examples/talk-llama/models/deepseek32.cpp new file mode 100644 index 00000000000..9a20e2ce907 --- /dev/null +++ b/examples/talk-llama/models/deepseek32.cpp @@ -0,0 +1,499 @@ +#include "models.h" + +#include "llama-kv-cache.h" +#include "llama-kv-cache-dsa.h" + +void llama_model_deepseek32::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-6; // eps for layer norm + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { + case 62: type = LLM_TYPE_685B_A37B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek32::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("DEEPSEEK32 architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_deepseek32::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_deepseek32::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const bool is_mla = hparams.is_mla(); + GGML_ASSERT(is_mla); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); + GGML_UNUSED(n_embd_head_v); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const int64_t n_indexer_head = hparams.indexer_n_head; + const int64_t n_embd_indexer_head = hparams.indexer_head_size; + const int64_t n_embd_indexer_head_rope = hparams.n_rot(); + const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope; + const uint32_t n_indexer_top_k = hparams.indexer_top_k; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation. + // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + + // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor + GGML_ASSERT(ext_factor >= 0.0f); + const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)); + + // use the original attn_factor to pre-scale the kq_scale + const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k)); + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + llm_graph_input_attn_k_dsa * inp_attn_dsa = build_attn_inp_k_dsa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * qr = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(qr, "qr", il); + + qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); + cb(qr, "qr", il); + + ggml_tensor * top_k = nullptr; + + // lightning indexer + { + ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr); + cb(indexer_q, "indexer_q", il); + + // split into {n_embd_indexer_head_rope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_pe = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_rope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, 0); + cb(indexer_q_pe, "indexer_q_pe", il); + + // and {n_embd_indexer_head_nope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_nope = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_nope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, + ggml_row_size(indexer_q->type, n_embd_indexer_head_nope)); + cb(indexer_q_nope, "indexer_q_nope", il); + + indexer_q_pe = ggml_rope_ext(ctx0, indexer_q_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_q_pe, "indexer_q_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, n_head, n_tokens} + indexer_q = ggml_concat(ctx0, indexer_q_pe, indexer_q_nope, 0); + cb(indexer_q, "indexer_q", il); + + ggml_tensor * indexer_k = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_k, cur); + cb(indexer_k, "indexer_k", il); + + indexer_k = build_norm(indexer_k, model.layers[il].indexer_k_norm, model.layers[il].indexer_k_norm_b, LLM_NORM, il); + cb(indexer_k, "indexer_k", il); + + // split into {n_embd_indexer_head_rope, 1, n_tokens} + ggml_tensor * indexer_k_pe = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_rope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, 0); + cb(indexer_k_pe, "indexer_k_pe", il); + + // and {n_embd_indexer_head_nope, 1, n_tokens} + ggml_tensor * indexer_k_nope = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_nope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, + ggml_row_size(indexer_k->type, n_embd_indexer_head_nope)); + cb(indexer_k_nope, "indexer_k_nope", il); + + indexer_k_pe = ggml_rope_ext(ctx0, indexer_k_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_k_pe, "indexer_k_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, 1, n_tokens} + indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0); + cb(indexer_k, "indexer_k", il); + + // perform Hadamard transform on indexer q and k + indexer_q = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_q); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_k); + cb(indexer_k, "indexer_k", il); + + // store indexer keys to KV cache + const auto * mctx_lid = inp_attn_dsa->mctx->get_lid(); + const auto & k_idxs_lid = inp_attn_dsa->get_k_idxs_lid(); + ggml_build_forward_expand(gf, mctx_lid->cpy_k(ctx0, indexer_k, k_idxs_lid, il)); + + // prepare indexer weights + ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur); + cb(indexer_weights, "indexer_weights", il); + + // get cached indexer keys + indexer_k = mctx_lid->get_k(ctx0, il); + + // split the batch into streams if needed + const auto n_stream = indexer_k->ne[3]; + indexer_q = ggml_view_4d(ctx0, indexer_q, indexer_q->ne[0], indexer_q->ne[1], indexer_q->ne[2]/n_stream, n_stream, indexer_q->nb[1], indexer_q->nb[2], indexer_q->nb[3]/n_stream, 0); + indexer_weights = ggml_view_4d(ctx0, indexer_weights, indexer_weights->ne[0], indexer_weights->ne[1]/n_stream, indexer_weights->ne[2], n_stream, indexer_weights->nb[1], indexer_weights->nb[2]/n_stream, indexer_weights->nb[3]/n_stream, 0); + + // calculate indexer kq + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3); + cb(indexer_k, "indexer_k", il); + + ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q); + cb(indexer_kq, "indexer_kq", il); + + // ReLU requires contiguous tensors + indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3)); + cb(indexer_kq, "indexer_kq", il); + + // apply ReLU + ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq); + cb(indexer_score, "indexer_score", il); + + // pre-scale weights to avoid scaling operations on huge indexer_score tensor + indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_embd_indexer_head * n_indexer_head))); + cb(indexer_weights, "indexer_weights", il); + + // multiply scores by indexer weights + indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights); + cb(indexer_score, "indexer_score", il); + + // sum by q n_indexer_head dimension + indexer_score = ggml_sum_rows(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + // permute result to match KQ mask + indexer_score = ggml_cont(ctx0, ggml_permute(ctx0, indexer_score, 2, 1, 0, 3)); + cb(indexer_score, "indexer_score", il); + + // mask indexer scores + ggml_tensor * indexer_kq_mask = inp_attn_dsa->get_kq_mask_lid(); + indexer_score = ggml_add(ctx0, indexer_score, indexer_kq_mask); + cb(indexer_score, "indexer_score", il); + + // get indices of top k indexer scores + uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k; + top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k)); + cb(top_k, "top_k", il); + } + + ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr); + cb(q, "q", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = + ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d( + ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = + ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(q_pe, "q_pe", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(k_pe, "k_pe", il); + + kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); + + // MLA attention + { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn_dsa, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, top_k, kq_scale, il); + } + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il, + nullptr, + model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 435d27281c6..07d6ab1b7cd 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -8,7 +8,8 @@ void llama_model_dots1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_142B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 12ac6f1ce88..abe737c335a 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -2,8 +2,9 @@ void llama_model_dream::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index 9b39c605e35..895cf690bd2 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -12,7 +12,7 @@ void llama_model_ernie4_5::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_0_3B; break; case 28: type = LLM_TYPE_21B_A3B; break; case 54: type = LLM_TYPE_300B_A47B; break; diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index ddf13c3028f..0948d7de656 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -3,7 +3,7 @@ void llama_model_eurobert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (hparams.n_layer == 12) { + if (hparams.n_layer() == 12) { type = LLM_TYPE_SMALL; // 0.2B } } diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 76d91982fc5..5aed9379400 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -20,13 +20,12 @@ void llama_model_exaone_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_30B_A3B; break; - case 48: - case 49: type = LLM_TYPE_235B_A22B; break; + case 48: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; } } @@ -50,9 +49,9 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -70,7 +69,7 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end - if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers)) { + if (i < (int) hparams.n_layer_dense_lead || (i >= n_layer)) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); @@ -95,7 +94,7 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); @@ -130,8 +129,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ ggml_tensor * inp_out_ids = build_inp_out_ids(); - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // use RoPE for SWA layers @@ -170,7 +168,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index c7e9960d718..676fb37b5a6 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -3,7 +3,7 @@ void llama_model_exaone::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 499e22dde81..863268abcef 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,7 +1,7 @@ #include "models.h" void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { - if (hparams.n_layer == 64) { // 32B + if (hparams.n_layer() == 64) { // 32B hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; uint32_t swa_period = 4; @@ -15,8 +15,11 @@ void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); - switch (hparams.n_layer) { + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_1_2B; break; case 64: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -37,22 +40,38 @@ void llama_model_exaone4::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { + const bool is_nextn = i >= n_layer; + int flags = 0; + if (is_nextn) { + // NextN/MTP layers are preserved in GGUF but are not executed yet. + flags |= TENSOR_SKIP; + } + auto & layer = layers[i]; - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, flags); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + if (!is_nextn) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + } } } diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index 94b65a3c7c9..d6ef2d51986 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -11,9 +11,9 @@ void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), true); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_0_5B; break; case 24: diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index ad546ef2db5..b2ad90b3272 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -3,7 +3,7 @@ void llama_model_falcon::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 60: type = LLM_TYPE_40B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 4e07f5f2bda..80ed3b1a460 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -21,7 +21,7 @@ void llama_model_gemma_embedding::load_arch_hparams(llama_model_loader & ml) { GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_0_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 1519682fdf6..651cd7e64de 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -3,7 +3,7 @@ void llama_model_gemma::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_2B; break; case 28: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index ae3f9ffb530..2fbfb15a94a 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -16,7 +16,7 @@ void llama_model_gemma2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_2B; break; case 42: type = LLM_TYPE_9B; break; case 46: type = LLM_TYPE_27B; break; diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index 63a2b380e71..690194529e3 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -17,7 +17,7 @@ void llama_model_gemma3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_8B; break; // Rnj-1 diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 6ec3a006081..83eb8250aa9 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -6,14 +6,14 @@ void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(swa_period); - hparams.n_layer_kv_from_start = 20; - hparams.f_attention_scale = 1.0f; + hparams.n_layer_kv_from_start = 20; + hparams.f_attention_scale = 1.0f; ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_E2B; break; case 35: type = LLM_TYPE_E4B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma4-assistant.cpp b/examples/talk-llama/models/gemma4-assistant.cpp new file mode 100644 index 00000000000..5b7a25a5aba --- /dev/null +++ b/examples/talk-llama/models/gemma4-assistant.cpp @@ -0,0 +1,200 @@ +#include "models.h" + +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { + hparams.n_embd_inp_impl = hparams.n_embd_out(); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl"); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); +} + +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + if (hparams.n_embd_out() == n_embd) { + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + + const int64_t n_embd_backbone = hparams.n_embd_inp(); + nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer_nextn; ++i) { + auto & layer = layers[i]; + + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff = hparams.n_ff(i); + + if (i == 0) { + nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0); + } + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); + + if (!hparams.is_swa(i)) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_backbone = hparams.n_embd_inp(); + + ggml_tensor * inp_tokens; + ggml_tensor * inp_h; + { + auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + inp_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); + cb(inp->embd, "inp_h", -1); + ggml_set_input(inp->embd); + inp_h = inp->embd; + res->t_inp_embd = inp->embd; + + res->add_input(std::move(inp)); + } + + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens); + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); + cb(x, "inp_embd_target", -1); + + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); + cb(xh, "inp_xh", -1); + + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh); + cb(cur, "pre_proj", -1); + + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer_nextn; ++il) { + const bool is_swa = hparams.is_swa(il); + + const int64_t n_embd_head = hparams.n_embd_head_k(il); + const int64_t n_head = hparams.n_head(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_norm, "attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + + if (il == n_layer_nextn - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + ggml_tensor * logits = build_lora_mm(model.output, cur); + cb(logits, "result_output", -1); + res->t_logits = logits; + + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur); + cb(h_next, "h_nextn", -1); + res->t_h_nextn = h_next; + + ggml_build_forward_expand(gf, logits); + ggml_build_forward_expand(gf, h_next); +} diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index 4f9d8b18bc7..6f7fcd645cb 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -2,12 +2,12 @@ void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); uint32_t n_kv_shared_layers = 0; ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); - hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.n_layer_kv_from_start = hparams.n_layer_all - (int32_t)n_kv_shared_layers; hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); @@ -19,7 +19,7 @@ void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_26B_A4B; break; case 35: type = LLM_TYPE_E2B; break; case 42: type = LLM_TYPE_E4B; break; @@ -142,6 +142,33 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); } +// TODO @ngxson : maybe improve this in the future +class llm_graph_input_logits_bias : public llm_graph_input_i { +public: + llm_graph_input_logits_bias(const llama_vocab & vocab) { + arr.resize(vocab.n_tokens(), 0.0f); + for (llama_token id : vocab.get_suppress_tokens()) { + if (0 <= id && id < (int32_t)vocab.n_tokens()) { + arr[id] = -INFINITY; + } + } + } + virtual ~llm_graph_input_logits_bias() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override { + const int64_t n_vocab = arr.size(); + ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias)); + } + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * logits_bias = nullptr; // F32 [n_vocab] + + std::vector<float> arr; +}; + llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -245,7 +272,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para } // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing - if (il == n_layer - 1 && inp_out_ids) { + // keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token) + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } @@ -345,7 +373,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] // TODO @ngxson : improve this - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); } @@ -376,6 +404,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para model.output_norm, nullptr, LLM_NORM_RMS, -1); + // Expose the post-output-norm hidden state (the LM-head input feature) so that + // MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the + // recurrent h input. This matches the reference (transformers/vLLM/SGLang), + // which feeds the drafter the target's post-final-norm hidden state. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cb(cur, "result_norm", -1); res->t_embd = cur; @@ -388,6 +427,16 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); } + // apply logits bias if needed (e.g. for gemma4_unified patch) + // this is to mirror the suppress_tokens patch on transformers, to avoid model from outputing <image|> and <audio|> tokens (which is a known issue related to the checkpoint) + // TODO: maybe handle this inside the sampling system in the future + if (!model.vocab.get_suppress_tokens().empty()) { + auto inp_bias = std::make_unique<llm_graph_input_logits_bias>(model.vocab); + inp_bias->logits_bias = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, inp_bias->arr.size()); + cur = ggml_add(ctx0, cur, inp_bias->logits_bias); + res->add_input(std::move(inp_bias)); + } + cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm-dsa.cpp b/examples/talk-llama/models/glm-dsa.cpp index af2b55ef563..11d91312def 100644 --- a/examples/talk-llama/models/glm-dsa.cpp +++ b/examples/talk-llama/models/glm-dsa.cpp @@ -33,13 +33,10 @@ void llama_model_glm_dsa::load_arch_hparams(llama_model_loader & ml) { } // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 79: type = LLM_TYPE_744B_A40B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -76,9 +73,9 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; @@ -135,8 +132,8 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); } - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + // NextN/MTP tensors (preserved but unused) - conditionally load for last n_layer_nextn + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 27654b8cba3..d60e47ddf0c 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -20,16 +20,13 @@ void llama_model_glm4_moe::load_arch_hparams(llama_model_loader & ml) { } // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + switch (hparams.n_layer()) { + case 46: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open - case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + case 92: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 default: type = LLM_TYPE_UNKNOWN; } } @@ -54,9 +51,9 @@ void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { // Load ALL tensors including NextN layer to satisfy total tensor count // but only PROCESS up to last layer (skipping final NextN layer) in forward pass - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -116,7 +113,7 @@ void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); @@ -161,8 +158,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa // Only process up to last layer (skip final NextN layer) // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -211,7 +207,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index 7c242fed298..b4326c5f210 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -5,13 +5,10 @@ void llama_model_glm4::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); // NextN/MTP parameters (GLM-OCR) - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; @@ -32,9 +29,9 @@ void llama_model_glm4::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -55,7 +52,7 @@ void llama_model_glm4::load_arch_tensors(llama_model_loader &) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); @@ -100,8 +97,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params // Only process up to last layer (skip final NextN layer) // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -140,7 +136,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index e2dcc8b1521..45afbccc121 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -2,7 +2,8 @@ void llama_model_gpt2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_SMALL; break; case 24: type = LLM_TYPE_MEDIUM; break; case 36: type = LLM_TYPE_LARGE; break; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 443e35addf2..ed5e8c50da2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -3,7 +3,8 @@ void llama_model_gptneox::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 6: switch (hparams.n_ff()) { case 512: type = LLM_TYPE_14M; break; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 27f6706ea10..eb23095aece 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -19,8 +19,8 @@ void llama_model_granite_hybrid::load_arch_hparams(llama_model_loader & ml) { hparams.rope_finetuned = rope_finetuned; // A layer is recurrent IFF the n_head_kv value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -71,7 +71,7 @@ void llama_model_granite_hybrid::load_arch_tensors(llama_model_loader &) { // norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -158,7 +158,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else { diff --git a/examples/talk-llama/models/granite-moe.cpp b/examples/talk-llama/models/granite-moe.cpp index 0d89bc1f340..115263c418f 100644 --- a/examples/talk-llama/models/granite-moe.cpp +++ b/examples/talk-llama/models/granite-moe.cpp @@ -12,7 +12,7 @@ void llama_model_granite_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index cda4aa231fa..4a75c5ff3cc 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include <sstream> + void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -7,12 +9,33 @@ void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + // Granite4 Vision uses array deepstack_mapping + ml.get_arr(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr, false); + + // Count the unique deepstack input indices + std::unordered_set<uint32_t> unique_deepstack_idxs; + for (const auto val : hparams.deepstack_mapping_arr) { + if (val >= 0) { + unique_deepstack_idxs.insert(val); + } + } + hparams.n_deepstack_layers = unique_deepstack_idxs.size(); + + // Ensure all values are valid (avoid overflow attacks) + for (const auto val : unique_deepstack_idxs) { + if (val > hparams.n_deepstack_layers) { + std::stringstream ss; + ss << "Invalid deepstack index: " << val << " > " << hparams.n_deepstack_layers; + throw std::runtime_error(ss.str()); + } + } + // Granite uses rope_finetuned as a switch for rope, so default to true bool rope_finetuned = true; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes @@ -112,6 +135,20 @@ llama_model_granite::graph::graph( ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + + // Granite Vision 4.1 deepstack: inject the projector stream that + // targets decoder layer `il` before the decoder runs. + // NOTE: skip the first deepstack layer since that's inpL + const auto & deepstack_emb_idx = hparams.deepstack_mapping_arr[il]; + if (il > 0 && deepstack_emb_idx >= 0) { + ggml_tensor * ds = ggml_view_2d(ctx0, + res->t_inp_embd, n_embd, n_tokens, + res->t_inp_embd->nb[1], + deepstack_emb_idx * n_embd * sizeof(float)); + inpL = ggml_add(ctx0, inpL, ds); + cb(inpL, "deepstack_in", il); + } + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 7c46ec1c0f2..42f38af6724 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -26,7 +26,7 @@ void llama_model_grok::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 64: type = LLM_TYPE_314B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index 1cab75adc7f..643a448e59a 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -7,7 +7,7 @@ void llama_model_grovemoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index deb3c9671f3..4d55f5e7f31 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -5,7 +5,7 @@ void llama_model_hunyuan_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_A13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f9ee37a24b6..f6cfdfb9458 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -2,7 +2,8 @@ void llama_model_internlm2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 2ba162605f1..415103ce23a 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -4,7 +4,7 @@ void llama_model_jais::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_3B; break; case 40: type = LLM_TYPE_13B; break; /* TODO: add variants */ diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 8966131441c..8610fcc9f82 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -3,7 +3,7 @@ void llama_model_jais2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; case 68: type = LLM_TYPE_70B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index 84ea63c3136..dba160b014f 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -8,11 +8,11 @@ void llama_model_jamba::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. case 12: // 900M 8x???M case 32: // 51B 16x?B diff --git a/examples/talk-llama/models/jina-bert-v2.cpp b/examples/talk-llama/models/jina-bert-v2.cpp index 4f8866ece4d..86ff1c84d1a 100644 --- a/examples/talk-llama/models/jina-bert-v2.cpp +++ b/examples/talk-llama/models/jina-bert-v2.cpp @@ -4,7 +4,7 @@ void llama_model_jina_bert_v2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); hparams.f_max_alibi_bias = 8.0f; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jina-bert-v3.cpp b/examples/talk-llama/models/jina-bert-v3.cpp index e0527529f56..1c974a6f16c 100644 --- a/examples/talk-llama/models/jina-bert-v3.cpp +++ b/examples/talk-llama/models/jina-bert-v3.cpp @@ -3,7 +3,7 @@ void llama_model_jina_bert_v3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_558M; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index ecffb105496..367f6990d1f 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -14,8 +14,8 @@ void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent } // MoE parameters - Kimi uses moe_intermediate_size = 1024 @@ -25,7 +25,7 @@ void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B default: type = LLM_TYPE_UNKNOWN; } @@ -53,7 +53,7 @@ void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; const int64_t ssm_d_conv = hparams.ssm_d_conv; - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); @@ -285,7 +285,7 @@ llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph ggml_build_forward_expand(gf, cur); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // === KDA Layer (Kimi Delta Attention) with Recurrent State === // Reference: vLLM kda.py const auto * mctx_cur = inp_rs->mctx; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index 29081344b24..97da8a6abb8 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -5,10 +5,13 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } - hparams.n_layer_dense_lead = hparams.n_layer; + + hparams.n_layer_dense_lead = hparams.n_layer(); + switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; @@ -16,10 +19,11 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_swa_impl[il] = !hparams.is_recr_impl[il]; } } } @@ -59,7 +63,7 @@ void llama_model_lfm2::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); @@ -235,8 +239,8 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); - cur = hparams.is_recurrent(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : - build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); + cur = hparams.is_recr(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : + build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/lfm2moe.cpp b/examples/talk-llama/models/lfm2moe.cpp index 12a66c05c7d..490f5c223eb 100644 --- a/examples/talk-llama/models/lfm2moe.cpp +++ b/examples/talk-llama/models/lfm2moe.cpp @@ -9,11 +9,11 @@ void llama_model_lfm2moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_8B_A1B; break; case 40: type = LLM_TYPE_24B_A2B; break; default: type = LLM_TYPE_UNKNOWN; @@ -55,7 +55,7 @@ void llama_model_lfm2moe::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 9722dde9f17..2ae89386447 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -2,11 +2,12 @@ void llama_model_llada_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // diffusion language model uses non-causal attention hparams.causal_attn = false; - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_A1_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 58b2c466e17..87d4259f9a7 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -2,14 +2,16 @@ void llama_model_llada::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } + // Set non-causal attention for diffusion models hparams.causal_attn = false; } diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index cef66d054b0..c0ec7e0a9ad 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -7,13 +7,13 @@ void llama_model_llama::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 8) { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8x7B; break; case 56: type = LLM_TYPE_8x22B; break; default: type = LLM_TYPE_UNKNOWN; } } else { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B case 22: type = LLM_TYPE_1B; break; case 26: type = LLM_TYPE_3B; break; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 0ff5376d571..7194c72a585 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -8,14 +8,15 @@ void llama_model_llama4::load_arch_hparams(llama_model_loader & ml) { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa == 0) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope + hparams.n_no_rope_layer_step = hparams.n_layer(); // always use rope } else { hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; hparams.n_swa = 8192; hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; - uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 84cfe399027..ae56a26a1f6 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -2,7 +2,8 @@ void llama_model_maincoder::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index 887a1fa509a..0d94e98281c 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -9,7 +9,7 @@ void llama_model_mamba::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: switch (hparams.n_embd) { case 768: type = LLM_TYPE_SMALL; break; diff --git a/examples/talk-llama/models/mamba2.cpp b/examples/talk-llama/models/mamba2.cpp index 3277ca53ec4..c5951cf0f7f 100644 --- a/examples/talk-llama/models/mamba2.cpp +++ b/examples/talk-llama/models/mamba2.cpp @@ -9,7 +9,7 @@ void llama_model_mamba2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: switch (hparams.n_embd) { case 768: type = LLM_TYPE_SMALL; break; diff --git a/examples/talk-llama/models/mellum.cpp b/examples/talk-llama/models/mellum.cpp new file mode 100644 index 00000000000..28823018bc0 --- /dev/null +++ b/examples/talk-llama/models/mellum.cpp @@ -0,0 +1,225 @@ +#include "models.h" + +void llama_model_mellum::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + uint32_t swa_period = 4; + const auto res = ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + if (res) { + hparams.set_swa_pattern(swa_period); + } else { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + } + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer()) { + case 28: type = LLM_TYPE_12B_A2_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mellum::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for Mellum"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for Mellum"); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_mellum::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } + return std::make_unique<graph<false>>(*this, params); +} + +template <bool iswa> +llama_model_mellum::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +template struct llama_model_mellum::graph<false>; +template struct llama_model_mellum::graph<true>; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index d0295ec116f..88989160570 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -8,18 +8,18 @@ void llama_model_mimo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); float value_scale = 0.0f; if (ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, value_scale, false) && value_scale != 1.0f) { hparams.f_attn_value_scale = value_scale; } - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_310B_A15B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -34,16 +34,14 @@ void llama_model_mimo2::load_arch_tensors(llama_model_loader &) { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - const uint32_t n_nextn = hparams.nextn_predict_layers; - - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { auto & layer = layers[i]; uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); uint32_t n_head = hparams.n_head(i); // NextN/MTP layers (the last n_nextn blocks) are preserved but disabled pending support - const bool is_nextn = (n_nextn > 0) && (static_cast<uint32_t>(i) >= n_layer - n_nextn); + const bool is_nextn = i >= n_layer; const int skip = is_nextn ? TENSOR_SKIP : 0; create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, skip); @@ -92,10 +90,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param const float v_scale = hparams.f_attn_value_scale; - // The last hparams.nextn_predict_layers blocks are MTP heads, currently inactive - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; uint32_t n_head_l = hparams.n_head(il); @@ -173,7 +168,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param } } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/minicpm.cpp b/examples/talk-llama/models/minicpm.cpp index 966d3af615c..fc3e5b171d5 100644 --- a/examples/talk-llama/models/minicpm.cpp +++ b/examples/talk-llama/models/minicpm.cpp @@ -3,7 +3,7 @@ void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { // Backward-compatible defaults for older MiniCPM GGUFs hparams.f_embedding_scale = 12.0f; - hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); + hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer())); hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -16,7 +16,7 @@ void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { // MiniCPM uses rope by default, unlike Granite which uses it as a switch hparams.rope_finetuned = true; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 52: type = LLM_TYPE_1B; break; case 40: type = LLM_TYPE_2B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index 1ffc54fa7c6..e011b1ff0a8 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -5,7 +5,7 @@ void llama_model_minicpm3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 22e291d73a3..b25435e4d97 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -5,7 +5,7 @@ void llama_model_minimax_m2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_230B_A10B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 1ac5a95ccdc..9a8e3f9a50b 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -18,7 +18,7 @@ void llama_model_mistral3::load_arch_hparams(llama_model_loader & ml) { } } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_3B; break; case 34: type = LLM_TYPE_8B; break; case 40: type = LLM_TYPE_14B; break; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index db228865d5d..c137e32e8fd 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -411,6 +411,18 @@ struct llama_model_stablelm : public llama_model_base { std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_mellum : public llama_model_base { + llama_model_mellum(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; struct llama_model_qwen : public llama_model_base { llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} @@ -810,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base { }; +struct llama_model_gemma4_assistant : public llama_model_base { + llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_gemma_embedding : public llama_model_base { llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; @@ -1030,6 +1055,19 @@ struct llama_model_deepseek2 : public llama_model_base { }; +struct llama_model_deepseek32 : public llama_model_base { + llama_model_deepseek32(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_deepseek2ocr : public llama_model_base { llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; @@ -1900,5 +1938,9 @@ struct llama_model_step35 : public llama_model_base { graph(const llama_model & model, const llm_graph_params & params); }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index e9b79ffc6dc..f3e9407e012 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -14,7 +14,15 @@ void llama_model_modern_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + // Some ModernBert derivatives (e.g. IBM Granite Embedding 97m R2) use + // SiLU/SwiGLU in the FFN instead of the default GELU/GeGLU. + hparams.llm_ffn_op = LLM_FFN_GEGLU; + std::string hidden_act; + if (ml.get_key(LLM_KV_HIDDEN_ACT, hidden_act, false)) { + hparams.llm_ffn_op = llm_ffn_op_type_from_string(hidden_act, LLM_FFN_GEGLU); + } + + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_47M; break; // granite-embedding-small case 22: @@ -144,7 +152,8 @@ llama_model_modern_bert::graph::graph(const llama_model & model, const llm_graph NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + hparams.llm_ffn_op, + LLM_FFN_SEQ, il); // attentions bypass the intermediate layer cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index 0229d20ed36..d094fd9f80b 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -5,7 +5,7 @@ void llama_model_mpt::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_30B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index a82f9c170b4..a456269347b 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -9,8 +9,8 @@ void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { // A layer is recurrent IFF the n_head_kv value is set to 0 and // the n_ff value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -22,7 +22,7 @@ void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; case 88: type = LLM_TYPE_120B_A12B; break; @@ -62,7 +62,7 @@ void llama_model_nemotron_h::load_arch_tensors(llama_model_loader &) { // all blocks use the attn norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -143,7 +143,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else if (hparams.n_ff(il) == 0) { diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 5d4a3b5c69e..6e2bd9a33ca 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -2,7 +2,8 @@ void llama_model_nemotron::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index f00d6eddfc9..4a08d7abd40 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -3,7 +3,7 @@ void llama_model_neo_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (hparams.n_layer == 28) { + if (hparams.n_layer() == 28) { type = LLM_TYPE_250M; } } diff --git a/examples/talk-llama/models/nomic-bert-moe.cpp b/examples/talk-llama/models/nomic-bert-moe.cpp index a17abe2c269..da4b62919bb 100644 --- a/examples/talk-llama/models/nomic-bert-moe.cpp +++ b/examples/talk-llama/models/nomic-bert-moe.cpp @@ -4,7 +4,7 @@ void llama_model_nomic_bert_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { if (arch == LLM_ARCH_NOMIC_BERT) { type = LLM_TYPE_137M; } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { diff --git a/examples/talk-llama/models/nomic-bert.cpp b/examples/talk-llama/models/nomic-bert.cpp index 5a8a5584457..e7fc72286a6 100644 --- a/examples/talk-llama/models/nomic-bert.cpp +++ b/examples/talk-llama/models/nomic-bert.cpp @@ -4,7 +4,7 @@ void llama_model_nomic_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { if (arch == LLM_ARCH_NOMIC_BERT) { type = LLM_TYPE_137M; } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index cfcf17bcb03..9f7a2ba60ef 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -4,7 +4,7 @@ void llama_model_olmo::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 22: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 7cc262f5504..cb52cdef720 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -17,7 +17,7 @@ void llama_model_olmo2::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 7976ae44a51..1e2baeb207f 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -2,7 +2,8 @@ void llama_model_olmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_A1_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 15b6c8c1205..3ab15d61f08 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -14,7 +14,7 @@ void llama_model_openai_moe::load_arch_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_20B; break; case 36: type = LLM_TYPE_120B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 9f76350fd4d..13120bd3236 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -3,12 +3,12 @@ void llama_model_openelm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_270M; break; - case 20: type = LLM_TYPE_450M; break; - case 28: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index bcb4bbba4b1..863a2822269 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -3,7 +3,7 @@ void llama_model_orion::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_14B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 7593f879b24..90f05c088c1 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -2,7 +2,8 @@ void llama_model_pangu_embed::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 8f3ed5f7b7d..81b1ad12cc0 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -3,7 +3,7 @@ void llama_model_phi2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index f8a4a4d5aa5..716ff814cc1 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -3,7 +3,7 @@ void llama_model_phi3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_14B; break; diff --git a/examples/talk-llama/models/phimoe.cpp b/examples/talk-llama/models/phimoe.cpp index 4575d6139cf..c332553bc7d 100644 --- a/examples/talk-llama/models/phimoe.cpp +++ b/examples/talk-llama/models/phimoe.cpp @@ -3,7 +3,7 @@ void llama_model_phimoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_16x3_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index c7ed1211c31..246144519e4 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -3,7 +3,7 @@ void llama_model_plamo::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index b713889fe72..b93cf48bc5c 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -11,11 +11,11 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; case 32: if (hparams.n_embd == 2048) { @@ -54,7 +54,7 @@ void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - bool is_mamba_layer = hparams.is_recurrent(i); + bool is_mamba_layer = hparams.is_recr(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -128,7 +128,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // check if this layer is Mamba or Attention - const bool is_mamba_layer = hparams.is_recurrent(il); + const bool is_mamba_layer = hparams.is_recr(il); if (is_mamba_layer) { // PLaMo-2 Mamba layer diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 29f3e803d68..16d0b1dcef7 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -13,7 +13,7 @@ void llama_model_plamo3::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_2B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index ce050919e6a..8ca325f5e2c 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -3,7 +3,8 @@ void llama_model_plm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 00467dbad7d..1f5dff3843c 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -3,7 +3,7 @@ void llama_model_qwen::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index a5147460bae..e9c2ea80a6b 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -2,7 +2,8 @@ void llama_model_qwen2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; case 32: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7cb03859deb..e831ed11aad 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -5,7 +5,8 @@ void llama_model_qwen2moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_A2_7B; break; case 28: type = LLM_TYPE_57B_A14B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 41b97fed956..1d0d2fab362 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -2,7 +2,8 @@ void llama_model_qwen3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; case 40: type = LLM_TYPE_14B; break; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 04ecc18fcdc..4b642cff467 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -13,21 +13,20 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -38,9 +37,7 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - const bool mtp_only = (hparams.nextn_predict_layers > 0) && - (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); @@ -69,7 +66,7 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -121,10 +118,10 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); }; - for (int i = 0; i < (int) n_main; ++i) { + for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - for (int i = (int) n_main; i < n_layer; ++i) { + for (int i = n_layer; i < n_layer_all; ++i) { load_block_mtp(i); } } @@ -158,8 +155,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_out_ids = build_inp_out_ids(); // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. - const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -168,7 +164,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { @@ -176,7 +172,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -208,16 +204,15 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; - if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + if (!cparams.embeddings_nextn_masked && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); } - // Final norm - cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); res->t_embd = cur; @@ -490,15 +485,15 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons // LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35 MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35 MTP currently only supports a single MTP block"); const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); // hparams.n_layer includes both main model layers and MTP layers. The MTP // layer is stored immediately after the main layers in model.layers[]. - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const int il = hparams.n_layer(); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -508,28 +503,41 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_input(inp->tokens); - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } cb(tok_embd, "mtp_tok_embd", il); + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + res->add_input(std::move(inp)); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - auto * inp_attn = build_attn_inp_kv(); - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); cb(h_norm, "mtp_hnorm", il); ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); @@ -611,18 +619,16 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "mtp_post_ffn", il); - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm ? layer.nextn.shared_head_norm : model.output_norm; GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "mtp_shared_head_norm", -1); ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index dc24f6ed537..eb5e9a406a1 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -16,21 +16,20 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -41,9 +40,7 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - const bool mtp_only = (hparams.nextn_predict_layers > 0) && - (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); @@ -75,7 +72,7 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -144,10 +141,10 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); }; - for (int i = 0; i < (int) n_main; ++i) { + for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - for (int i = (int) n_main; i < n_layer; ++i) { + for (int i = n_layer; i < n_layer_all; ++i) { load_block_mtp(i); } } @@ -181,8 +178,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_out_ids = build_inp_out_ids(); // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. - const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -191,7 +187,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { @@ -199,7 +195,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -231,16 +227,16 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; + // post-norm hidden state feeds both the LM head and the MTP seed below + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; - if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + if (!cparams.embeddings_nextn_masked && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); } - // Final norm - cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); res->t_embd = cur; @@ -554,13 +550,13 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c // LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35MOE MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const int il = hparams.n_layer(); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -571,29 +567,41 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_input(inp->tokens); - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } cb(tok_embd, "mtp_tok_embd", il); + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + res->add_input(std::move(inp)); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - auto * inp_attn = build_attn_inp_kv(); + auto * inp_attn = build_attn_inp_kv(); - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); cb(h_norm, "mtp_hnorm", il); ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); @@ -708,17 +716,16 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "mtp_post_ffn", il); - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm ? layer.nextn.shared_head_norm : model.output_norm; GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn= cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "mtp_shared_head_norm", -1); ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index a4f8e1379c9..317e668bec7 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -1,10 +1,10 @@ #include "models.h" void llama_model_qwen3moe::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index 1d873427db5..97200a44072 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -14,15 +14,15 @@ void llama_model_qwen3next::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - { + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_80B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -68,7 +68,7 @@ void llama_model_qwen3next::load_arch_tensors(llama_model_loader &) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { // Attention layers create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); @@ -129,7 +129,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 5defd893944..724d6140d19 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -4,7 +4,8 @@ void llama_model_qwen3vl::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_1_7B; break; case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; case 64: type = LLM_TYPE_32B; break; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index 5b77df57122..7c41592f772 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -5,7 +5,8 @@ void llama_model_qwen3vlmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index bf3949a9092..a46c358fa68 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -2,7 +2,8 @@ void llama_model_refact::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index ca8e009615e..fc276ce591b 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -2,12 +2,13 @@ void llama_model_rnd1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } + // Set non-causal attention for diffusion models hparams.causal_attn = false; } diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index ba2a9dfa0db..0b5013dc758 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -9,7 +9,7 @@ void llama_model_rwkv6::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 566b8cdcb54..6c7db514435 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -9,7 +9,7 @@ void llama_model_rwkv6qwen2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 7574b252621..67c51f5b59c 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -10,7 +10,7 @@ void llama_model_rwkv7::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: switch (hparams.n_embd) { case 768: type = LLM_TYPE_190M; break; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 806cba574be..57de881a091 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -2,7 +2,8 @@ void llama_model_seed_oss::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 64: type = LLM_TYPE_36B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 4231cccc666..a8e3d957f1f 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -15,14 +15,14 @@ void llama_model_smallthinker::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; + hparams.n_no_rope_layer_step = hparams.n_layer(); } ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_4B; break; case 52: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 90e7d473eaf..c67d967b204 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -4,7 +4,7 @@ void llama_model_smollm3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); hparams.n_no_rope_layer_step = 4; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 4da7f7aefcf..bf6087b8796 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -3,7 +3,7 @@ void llama_model_stablelm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_12B; break; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index e131af058bc..f73a88fd4e9 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -2,7 +2,8 @@ void llama_model_starcoder::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 36: type = LLM_TYPE_3B; break; case 42: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 9c207c02885..b81b469374a 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -2,7 +2,8 @@ void llama_model_starcoder2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_3B; break; case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_15B; break; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index 3b68e68707a..e2218c58704 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -22,24 +22,39 @@ void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); - switch (hparams.n_layer) { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer(), false); + + // NextN/MTP (Step3p5): extra decoder block appended beyond the main stack. + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { case 45: type = LLM_TYPE_196B_A11B; break; default: type = LLM_TYPE_UNKNOWN; } } -void llama_model_step35::load_arch_tensors(llama_model_loader &) { +void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + // Trunk-only: the GGUF declares MTP layers in metadata but the actual MTP + // tensors live in a separate file (e.g. user split target/draft). Mark + // MTP tensors NOT_REQUIRED so the trunk loads cleanly. + const std::string mtp_probe = "blk." + std::to_string(n_layer) + ".nextn.eh_proj.weight"; + const bool trunk_only = (hparams.n_layer_nextn > 0) && (ml.get_weight(mtp_probe.c_str()) == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + const int mtp_flags = trunk_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, trunk_flags); // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. @@ -51,14 +66,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { n_rot_max = n_rot; } - for (int i = 0; i < n_layer; ++i) { + auto load_block_trunk = [&](int i, int flags) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -70,13 +85,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags); // head-wise attention gate (Step35 self_attn.g_proj) layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense MLP (leading dense blocks) layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -95,10 +110,86 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + }; + + auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the + // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). + // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. + // + // Only the FIRST MTP block (i == n_main) is required for the + // single-block MTP runtime; trailing MTP blocks are always tolerated + // as missing so pruned GGUFs (block 0 only) load cleanly. Override + // mtp_flags to NOT_REQUIRED for those. + const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + + // dense MLP (leading dense blocks) — present if the MTP block isn't MoE + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < n_layer; ++i) { + load_block_trunk(i, trunk_flags); + } + // Only the first MTP block (i == n_main) is required at runtime — the + // single-block-MTP graph in build_arch_graph always uses that one. + // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with + // all MTP layers still works) but tolerated when absent via the pruning + // path. See scripts/prune_step35_extra_mtp.py for the pruner. + for (int i = n_layer; i < n_layer_all; ++i) { + load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); } } std::unique_ptr<llm_graph_context> llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique<graph_mtp>(*this, params); + } return std::make_unique<graph>(*this, params); } @@ -111,6 +202,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -198,8 +290,8 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "attn_proj", il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -257,6 +349,13 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cur = inpL; + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -267,3 +366,192 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Step3p5 (MoE) +llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); + + // Single-block MTP only: always run the first trained MTP block (Qwen + // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to + // be a much deeper refactor than this PR justifies; the trailing MTP + // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just + // block 0) also work — see load_arch_tensors below and + // scripts/prune_step35_extra_mtp.py. + const int il = hparams.n_layer(); + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + // mtp_block: full Step3p5 decoder layer (attention with optional head-wise gate, then MoE/dense FFN) + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Qcur, "mtp_Qcur", il); + cb(Kcur, "mtp_Kcur", il); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + if (layer.attn_q_norm) { + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + } + if (layer.attn_k_norm) { + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + } + + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "mtp_Qcur_pos", il); + cb(Kcur, "mtp_Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "mtp_attn_out", il); + + // head-wise attention gate: sigmoid(g_proj(x)) + if (layer.wqkv_gate) { + ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "mtp_attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "mtp_attn_gate_sigmoid", il); + + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "mtp_attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "mtp_attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "mtp_attn_gated", il); + } + + cur = build_lora_mm(layer.wo, attn_out, layer.wo_s); + cb(cur, "mtp_attn_proj", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_ffn_norm", il); + + // FFN: dense MLP or MoE (mirrors trunk path) + if (layer.ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + layer.ffn_up, layer.ffn_up_b, nullptr, + layer.ffn_gate, layer.ffn_gate_b, nullptr, + layer.ffn_down, layer.ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + } else { + ggml_tensor * moe_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "mtp_ffn_moe_out", il); + + ggml_tensor * sh_out = build_ffn(cur, + layer.ffn_up_shexp, nullptr, nullptr, + layer.ffn_gate_shexp, nullptr, nullptr, + layer.ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "mtp_ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "mtp_ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "STEP35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "STEP35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 73e32741406..b0e3f062572 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -9,10 +9,10 @@ void llama_model_t5::load_arch_hparams(llama_model_loader & ml) { hparams.dec_start_token_id = dec_start_token_id; } - hparams.dec_n_layer = hparams.n_layer; + hparams.dec_n_layer = hparams.n_layer(); ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 6: type = LLM_TYPE_60M; break; // t5-small case 8: type = LLM_TYPE_80M; break; // flan-t5-small case 12: diff --git a/examples/talk-llama/models/talkie.cpp b/examples/talk-llama/models/talkie.cpp index 1258eeb19b6..393e8f65bf4 100644 --- a/examples/talk-llama/models/talkie.cpp +++ b/examples/talk-llama/models/talkie.cpp @@ -4,7 +4,7 @@ void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index d6d1c7a2e5d..3135001293a 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -2,7 +2,8 @@ void llama_model_xverse::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; case 80: type = LLM_TYPE_65B; break; From ba573929cd31ddea3c77c5dc9caae78da8117123 Mon Sep 17 00:00:00 2001 From: Christopher Albert <albert@tugraz.at> Date: Tue, 9 Jun 2026 08:34:31 +0200 Subject: [PATCH 257/289] coreml : fix --quantize crash for mlprogram format; fix --optimize-ane label (#3868) commit 8b92060 switched ct.convert() to mlprogram, but did not update the --quantize path. quantize_weights() from neural_network.quantization_utils only works with the legacy neuralnetwork format. Running with --quantize crashed with: Exception: MLModel of type mlProgram cannot be loaded just from the model spec object. It also needs the path to the weights file. Fix: pass compute_precision=ct.precision.FLOAT16 into ct.convert() when --quantize is set. This matches the original intent of nbits=16 (F16 storage) without changing the quantization scheme or model accuracy. Also fix the three boolean CLI flags (--encoder-only, --quantize, --optimize-ane) to use a _str_to_bool helper so that both --flag True and --flag False parse correctly. The type=bool form accepted "False" as True because bool("False") == True. Remove the "currently broken" label from --optimize-ane: the ANE path (WhisperANE with Conv2d attention and LayerNormANE) converts and loads correctly with both PyTorch 2.x and coremltools 9.x. --- models/convert-whisper-to-coreml.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py index 66827b6d420..7cf07754a89 100644 --- a/models/convert-whisper-to-coreml.py +++ b/models/convert-whisper-to-coreml.py @@ -8,10 +8,19 @@ from typing import Dict from typing import Optional from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase -from coremltools.models.neural_network.quantization_utils import quantize_weights from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions from whisper import load_model + +def _str_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true", "1", "yes"): + return True + if v.lower() in ("false", "0", "no"): + return False + raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'") + # Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues. # The Whisper implementation expects a specific behavior from # torch.nn.functional.scaled_dot_product_attention that differs between PyTorch @@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False): inputs=[ct.TensorType(name="logmel_data", shape=input_shape)], outputs=[ct.TensorType(name="output")], compute_units=ct.ComputeUnit.ALL, + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model def convert_decoder(hparams, model, quantize=False): @@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False): ct.TensorType(name="token_data", shape=tokens_shape, dtype=int), ct.TensorType(name="audio_data", shape=audio_shape) ], + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True) - parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False) - parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False) - parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False) + parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False) + parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False) + parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False) args = parser.parse_args() if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]: From df7638d8229a243af8a4b5a8ae557e0d74e0a0ae Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 9 Jun 2026 12:51:00 +0200 Subject: [PATCH 258/289] ci : pin github actions to commit sha's (#3865) --- .github/workflows/bindings-go.yml | 4 +- .github/workflows/bindings-ruby.yml | 4 +- .github/workflows/build-android.yml | 8 ++-- .github/workflows/build-clang.yml | 4 +- .github/workflows/build-coreml.yml | 2 +- .github/workflows/build-cpu.yml | 10 ++--- .github/workflows/build-freebsd.yml | 4 +- .github/workflows/build-gcc.yml | 6 +-- .github/workflows/build-macos.yml | 2 +- .github/workflows/build-quantize.yml | 2 +- .github/workflows/build-sanitize.yml | 2 +- .github/workflows/build-self-hosted.yml | 10 ++--- .github/workflows/build-sycl.yml | 4 +- .github/workflows/build-vad.yml | 2 +- .github/workflows/build-wasm.yml | 2 +- .github/workflows/build-windows.yml | 2 +- .github/workflows/deploy-examples-wasm.yml | 8 ++-- .github/workflows/docker.yml | 2 +- .github/workflows/examples.yml | 4 +- .github/workflows/release.yml | 46 +++++++++++----------- 20 files changed, 64 insertions(+), 64 deletions(-) diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index 44381a4b411..91f869e99cf 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -13,10 +13,10 @@ jobs: ubuntu-22: runs-on: ubuntu-22.04 steps: - - uses: actions/setup-go@v6 + - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: '^1.23' - - uses: actions/checkout@v6 + - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: | cd bindings/go make test diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 0c31701a2a3..80a243e4c98 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -25,8 +25,8 @@ jobs: run: working-directory: bindings/ruby steps: - - uses: ruby/setup-ruby@v1 + - uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0 with: ruby-version: '3.2' - - uses: actions/checkout@v6 + - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: rake test diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml index 42673166cf3..571c35872c8 100644 --- a/.github/workflows/build-android.yml +++ b/.github/workflows/build-android.yml @@ -30,12 +30,12 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: path: whisper - name: Install Java - uses: actions/setup-java@v5 + uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 with: distribution: zulu java-version: 21 @@ -59,10 +59,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: set up JDK 11 - uses: actions/setup-java@v5 + uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 with: java-version: '11' distribution: 'temurin' diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml index 5308164cc68..20b7fec6494 100644 --- a/.github/workflows/build-clang.yml +++ b/.github/workflows/build-clang.yml @@ -48,7 +48,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set CCACHE_DIR run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV @@ -95,7 +95,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-coreml.yml b/.github/workflows/build-coreml.yml index d383d9ae0a7..8dedd7819ed 100644 --- a/.github/workflows/build-coreml.yml +++ b/.github/workflows/build-coreml.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout with full history - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: fetch-depth: 0 diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index 9c8e0586fcb..e2b74881ea5 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -66,7 +66,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -94,7 +94,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -122,7 +122,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -150,7 +150,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-freebsd.yml b/.github/workflows/build-freebsd.yml index 847ae975e30..64e78ad62f8 100644 --- a/.github/workflows/build-freebsd.yml +++ b/.github/workflows/build-freebsd.yml @@ -33,10 +33,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Build - uses: cross-platform-actions/action@v0.27.0 + uses: cross-platform-actions/action@fe0167d8082ac584754ef3ffb567fded22642c7d # v0.27.0 with: operating_system: freebsd version: '14.2' diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml index b1b04c24034..3d8b5137344 100644 --- a/.github/workflows/build-gcc.yml +++ b/.github/workflows/build-gcc.yml @@ -45,7 +45,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set CCACHE_DIR run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV @@ -90,7 +90,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -128,7 +128,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set CCACHE_DIR run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 804f8bbb642..8b209e4eec8 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -44,7 +44,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml index 69ab2c34638..1c9576af7f1 100644 --- a/.github/workflows/build-quantize.yml +++ b/.github/workflows/build-quantize.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-sanitize.yml b/.github/workflows/build-sanitize.yml index 9250fe81023..e517f7bade4 100644 --- a/.github/workflows/build-sanitize.yml +++ b/.github/workflows/build-sanitize.yml @@ -39,7 +39,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml index 3fe131b9ba5..2286b63d6e7 100644 --- a/.github/workflows/build-self-hosted.yml +++ b/.github/workflows/build-self-hosted.yml @@ -52,7 +52,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -66,7 +66,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -80,7 +80,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -94,7 +94,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -107,7 +107,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml index c76954e49cf..e5361645f1e 100644 --- a/.github/workflows/build-sycl.yml +++ b/.github/workflows/build-sycl.yml @@ -46,7 +46,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: add oneAPI to apt shell: bash @@ -105,7 +105,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: add oneAPI to apt shell: bash diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml index 3c5ebec2026..dd0efa33efe 100644 --- a/.github/workflows/build-vad.yml +++ b/.github/workflows/build-vad.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml index 45c77c0be4c..c17a44ae455 100644 --- a/.github/workflows/build-wasm.yml +++ b/.github/workflows/build-wasm.yml @@ -37,7 +37,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Setup emsdk uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 156a57f74b6..76b7a7370ce 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -43,7 +43,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Setup ${{ matrix.sys }} uses: msys2/setup-msys2@cafece8e6baf9247cf9b1bf95097b0b983cc558d # v2 diff --git a/.github/workflows/deploy-examples-wasm.yml b/.github/workflows/deploy-examples-wasm.yml index e7fdae77854..55df14720b1 100644 --- a/.github/workflows/deploy-examples-wasm.yml +++ b/.github/workflows/deploy-examples-wasm.yml @@ -22,10 +22,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Setup Pages - uses: actions/configure-pages@v5 + uses: actions/configure-pages@983d7736d9b0ae728b81ab479565c72886d7745b # v5 - name: Setup emsdk uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 @@ -88,10 +88,10 @@ jobs: find staging -type f | sort - name: Upload artifact - uses: actions/upload-pages-artifact@v4 + uses: actions/upload-pages-artifact@7b1f4a764d45c48632c6b24a0339c27f5614fb0b # v4 with: path: ./staging - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@v4 + uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b4c455b92e9..2d95e1a697f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Check out the repo - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set up Docker Buildx uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index eaa4fe4df61..ac811712e78 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -19,7 +19,7 @@ jobs: node-version: [ 16.x, 18.x ] steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Dependencies run: | @@ -29,7 +29,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v6 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6 with: node-version: ${{ matrix.node-version }} cache: 'npm' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c3ae9de4deb..11d47546caa 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout with full history - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: fetch-depth: 0 @@ -100,7 +100,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -130,7 +130,7 @@ jobs: -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: path: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz name: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz @@ -156,10 +156,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - name: Fetch SDL2 and set SDL2_DIR if: matrix.sdl2 == 'ON' @@ -188,32 +188,32 @@ jobs: - name: Upload SDL2.dll if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ${{ matrix.s2arc }}_SDL2.dll path: build/bin/${{ matrix.build }}/SDL2.dll - name: Upload whisper dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/whisper.dll - name: Upload ggml dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ggml_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml.dll overwrite: true - name: Upload ggml base dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ggml_base_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-base.dll - name: Upload ggml cpu dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ggml_cpu_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-cpu.dll @@ -225,7 +225,7 @@ jobs: - name: Upload binaries if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper-bin-${{ matrix.arch }}.zip path: whisper-bin-${{ matrix.arch }}.zip @@ -253,17 +253,17 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - name: Install OpenBLAS and pkgconfiglite if: matrix.blas == 'ON' @@ -310,7 +310,7 @@ jobs: - name: Upload binaries if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper-blas-bin-${{ matrix.arch }}.zip path: whisper-blas-bin-${{ matrix.arch }}.zip @@ -332,7 +332,7 @@ jobs: sdl2_ver: 2.28.5 steps: - name: Clone repository - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Install Ninja id: install_ninja @@ -459,7 +459,7 @@ jobs: echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - name: Install 7-Zip run: choco install 7zip -y @@ -516,7 +516,7 @@ jobs: - name: Upload binaries if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip @@ -531,7 +531,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Configure run: | @@ -573,7 +573,7 @@ jobs: - name: Upload artifacts if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip @@ -594,7 +594,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: fetch-depth: 0 @@ -607,7 +607,7 @@ jobs: # Downloads all the artifacts from the previous jobs - name: Download artifacts id: download-artifact - uses: actions/download-artifact@v7 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7 with: path: ./artifact @@ -627,7 +627,7 @@ jobs: - name: Upload release id: upload_release - uses: actions/github-script@v3 + uses: actions/github-script@ffc2c79a5b2490bd33e0a41c1de74b877714d736 # v3 with: github-token: ${{secrets.GITHUB_TOKEN}} script: | From 782f1226c8d9c49a6c64d654bacfe15531913a6c Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Mon, 8 Jun 2026 10:22:44 +0200 Subject: [PATCH 259/289] cuda: reset cuda context after reading memory size (llama/23935) * cuda: reset device in get_memory function if no backend is active * also count device and host buffers * exclude hip and musa from counting and device reset * use device mutex instead of atomic * undo backend_free function move --- ggml/src/ggml-cuda/ggml-cuda.cu | 75 +++++++++++++++++++++++++++++---- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index f5293ad4cbb..e779a9be9e9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -622,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() { // cuda buffer +struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; + std::string pci_bus_id; + int op_offload_min_batch_size; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::mutex device_mutex; + int active_count = 0; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +}; + struct ggml_backend_cuda_buffer_context { int device; void * dev_ptr = nullptr; @@ -639,6 +651,13 @@ struct ggml_backend_cuda_buffer_context { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete ctx; } @@ -791,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); } @@ -1490,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { } static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1498,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) { return nullptr; } + ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0. + void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { @@ -1523,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm buffer->buft = buft; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return buffer; } @@ -3140,6 +3179,12 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { static void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete cuda_ctx; delete backend; } @@ -4871,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { // backend device -struct ggml_backend_cuda_device_context { - int device; - std::string name; - std::string description; - std::string pci_bus_id; - int op_offload_min_batch_size; -}; - static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->name.c_str(); @@ -4967,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::lock_guard<std::mutex> lock(ctx->device_mutex); +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemGetInfo(free, total)); @@ -4993,6 +5035,13 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } #endif // defined(__linux__) +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA + // context that permanently consumes VRAM. Reset the device to free it. + if (ctx->active_count == 0) { + CUDA_CHECK(cudaDeviceReset()); + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { @@ -5687,13 +5736,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) { return nullptr; } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device); + ggml_backend_t cuda_backend = new ggml_backend { /* .guid = */ ggml_backend_cuda_guid(), /* .iface = */ ggml_backend_cuda_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .device = */ dev, /* .context = */ ctx, }; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return cuda_backend; } From fbf720dc9f3570ed98bd5e43806fbc4a53428084 Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Mon, 8 Jun 2026 03:40:37 -0500 Subject: [PATCH 260/289] vulkan: Use cm2 decode_vector for mul_mat_id B matrix loads (llama/23991) This allows vec4 loads of the B elements. Also increase BK to 64 when this is enabled. Neither of these alone is consistently faster, but together these give a nice speedup. In ggml-vulkan.cpp, we need to make sure the B matrix alignment and stride are multiples of 4. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 149 +++++++++++++++--- .../vulkan-shaders/mul_mm_cm2.comp | 47 +++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 25 +-- 3 files changed, 183 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index fc9bc8fe376..2dd8cd2fbd9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1976,6 +1976,9 @@ struct ggml_backend_vk_context { // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. vk_pipeline_struct * prealloc_y_last_pipeline_used {}; const ggml_tensor * prealloc_y_last_tensor_used {}; + // True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback. + // If false, then it's contiguous. + bool prealloc_y_last_decode_vector_staging {}; // Track which nodes have been used since the last sync, and whether they were written to std::vector<const ggml_tensor *> unsynced_nodes_written; @@ -3652,9 +3655,10 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size }; - m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; - s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; + const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u; + l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -8110,6 +8114,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_sync_buffers(ctx, subctx); } +// Copy/convert tensor into a caller-defined dense layout. Destination strides +// are in output elements, not bytes. +static void ggml_vk_cpy_to_strided( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, + const vk_subbuffer & in, const vk_subbuffer & out, + uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) { + VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; + std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array<uint32_t, 3> elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); +} + static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { switch(type) { case GGML_TYPE_Q8_1: @@ -8367,24 +8405,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -8642,24 +8684,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -9110,12 +9156,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + // B must already be, or be convertible to, the matmul B type used by this path. + const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector && + (f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) && + (src1->type == GGML_TYPE_F32 || src1->type == f16_type); + // If B is copied to prealloc_y, we can choose a 4-element-aligned row stride. + const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type; + // Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned. + const bool y_decode_vector_aligned = + (ne10 % 4 == 0) && + (y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0); + // Stage B only when decode-vector is available and direct B reads would be misaligned. + const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned; +#else + const bool y_decode_vector_staging = false; +#endif + const bool y_non_contig = y_decode_vector_staging || + (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); - // If src0 is BF16, try to use a BF16 x BF16 multiply - ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10; const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -9154,11 +9218,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = padded_n * ne10 * ne12 * ne13; + const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13; const uint64_t d_ne = ggml_nelements(dst); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); - const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t ids_sz = nbi2; @@ -9168,13 +9232,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; vk_pipeline to_q8_1 = nullptr; + auto make_y_staged_dst = [&]() { + ggml_tensor y_staged_dst = *src1; + y_staged_dst.type = f16_type; + y_staged_dst.nb[0] = ggml_type_size(f16_type); + y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride; + y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n; + y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2]; + return y_staged_dst; + }; + if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + ggml_tensor y_staged_dst; + const ggml_tensor * y_staged_dst_ptr = nullptr; + if (y_decode_vector_staging) { + y_staged_dst = make_y_staged_dst(); + y_staged_dst_ptr = &y_staged_dst; + } + + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -9292,30 +9373,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + if (y_decode_vector_staging) { + const ggml_tensor y_staged_dst = make_y_staged_dst(); + const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type); + ggml_vk_cpy_to_strided( + ctx, subctx, to_fp16_vk_1, src1, + ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), + (uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size)); + } else { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + } ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } ggml_vk_sync_buffers(ctx, subctx); uint32_t stride_batch_x = ne00*ne01; - uint32_t stride_batch_y = ne10*ne11; + uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10; + uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11; if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); @@ -9330,7 +9428,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf, - ne01, ne21, ne10, ne10, ne10, ne01, + ne01, ne21, ne10, ne10, stride_b_y, ne01, stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT @@ -9488,24 +9586,28 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -13730,7 +13832,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); @@ -14310,6 +14414,8 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); ctx->prealloc_y_last_pipeline_used = {}; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); @@ -14360,6 +14466,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->sync_staging); ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; @@ -15539,6 +15647,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; if (ctx->prealloc_size_add_rms_partials) { ggml_vk_preallocate_buffers(ctx, nullptr); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 250d708479b..2656fe1c3e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -11,6 +11,9 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GGML_VULKAN_COOPMAT2_DECODE_VECTOR +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -69,10 +72,13 @@ layout (push_constant) uniform parameter layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +layout (binding = 1) readonly buffer B4 {B_TYPEV4 data_b_v4[];}; +#endif #if QUANT_K > 1 #include "dequant_funcs_cm2.glsl" -#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector) +#if defined(dequantFuncA_v) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) #define DECODEFUNCA , dequantFuncA, dequantFuncA_v #else #define DECODEFUNCA , dequantFuncA @@ -113,11 +119,33 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i const uint row_i = blockCoords[0]; const u16vec4 row_idx = row_ids[row_i]; - B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + // The decode-vector path gives B a K-dimension tensor-layout block size of BK. + const uint k = blockCoords[1] * BK + coordInBlock[1]; +#else + const uint k = blockCoords[1]; +#endif + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k]; return ret; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +B_TYPEV4 decodeFuncB_v(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + const u16vec4 row_idx = row_ids[row_i]; + const uint k = blockCoords[1] * BK + coordInBlock[1]; + const uint base = row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k; + + return data_b_v4[base >> 2]; +} +#define DECODEFUNCB , decodeFuncB, decodeFuncB_v +#else +#define DECODEFUNCB , decodeFuncB +#endif + D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) { uint dr = ir * BM + r; @@ -287,6 +315,9 @@ void main() { tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); #endif +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + tensorLayoutB = setTensorLayoutBlockSizeNV(tensorLayoutB, 1, BK); +#endif // Use end_k rather than p.K as the dimension because that's what // we need to bound check against when using split_k. @@ -499,7 +530,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -507,7 +538,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -543,7 +574,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -551,7 +582,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -588,7 +619,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif @@ -600,7 +631,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d65cd12b287..8fc00362870 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -457,6 +457,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (coopmat) { base_dict["COOPMAT"] = "1"; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (coopmat2) { + base_dict["GGML_VULKAN_COOPMAT2_DECODE_VECTOR"] = "1"; + } +#endif const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; @@ -523,11 +528,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -548,8 +553,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } } @@ -579,13 +584,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) From 490e50056c2a96f1ebd0cabca43b031ab865dc44 Mon Sep 17 00:00:00 2001 From: Nikhil Jain <nikhil.jain0987@gmail.com> Date: Mon, 8 Jun 2026 08:07:15 -0700 Subject: [PATCH 261/289] Implement 2D workgroups for scale, binary, and unary ops (llama/24044) * Only run webgpu CI on my fork * Add webgpu only workflow * Implement 2d workgroups for more operations * fix * Fix type * Move back to global_invocation_id --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 32 ++++++++++++------- ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl | 13 +++++--- ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl | 10 +++--- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 11 ++++--- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c6cfb0bbbad..94a108dfa77 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -621,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; - size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = + ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); @@ -1362,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); @@ -2169,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, @@ -2244,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, } } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx, @@ -2673,8 +2678,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, @@ -3751,7 +3758,8 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * + ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index 605de7aa7be..f262c4a8f6a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x < params.ne) { - let src0_i = params.offset_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + src1_index(gid.x); - update(params.offset_dst + gid.x, src0_i, src1_i); +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i < params.ne) { + let src0_i = params.offset_src0 + src0_index(i); + let src1_i = params.offset_src1 + src1_index(i); + update(params.offset_dst + i, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 3b70a876d70..6c76ed69e45 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -43,12 +43,14 @@ struct Params { var<storage, read_write> src: array<f32>; @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + var i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i >= params.ne) { return; } - - var i = gid.x; let i3 = i / (params.ne2 * params.ne1 * params.ne0); i = i % (params.ne2 * params.ne1 * params.ne0); let i2 = i / (params.ne1 * params.ne0); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 8e34e1c9ca0..cb342c47263 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (flat_i >= params.ne) { return; } - var i = gid.x; + var i = flat_i; let ne2 = params.ne2; #ifdef DIAG let ne1 = params.ne0; @@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { #ifdef INPLACE src[params.offset_src + src_idx] = res; #else - dst[params.offset_dst + gid.x] = res; + dst[params.offset_dst + flat_i] = res; #endif } From 15e5d401d18dae5968d98ef54241317d5b8bab33 Mon Sep 17 00:00:00 2001 From: Nikhil Jain <nikhil.jain0987@gmail.com> Date: Mon, 8 Jun 2026 08:07:31 -0700 Subject: [PATCH 262/289] Handle buffer overlap / buffer aliasing for concat operator (llama/24000) * Only run webgpu CI on my fork * Add webgpu only workflow * handle buffer overlap case for concat operator * restore build-webgpu.yml Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Run clang-format * Update ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: Reese Levine <reeselevine1@gmail.com> --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 17 ++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 75 ++++++++++++------- ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl | 20 ++++- 3 files changed, 79 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index a5e7de785b4..c75a98a8dd4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -448,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash { /** Concat **/ struct ggml_webgpu_concat_pipeline_key { - int type; + int type; + bool src_overlap; - bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; } + bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { + return type == other.type && src_overlap == other.src_overlap; + } }; struct ggml_webgpu_concat_pipeline_key_hash { size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } }; @@ -2634,6 +2638,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_concat_pipeline_key key = {}; key.type = context.dst->type; + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -2656,11 +2661,17 @@ class ggml_webgpu_shader_lib { GGML_ABORT("Unsupported type for concat shader"); } + if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_concat, defines); - auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>(); decisions->wg_size = context.max_wg_size; + decisions->src_overlap = key.src_overlap; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; concat_pipelines[key] = pipeline; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 94a108dfa77..79d5138029d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2310,33 +2310,6 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; - std::vector<uint32_t> params = { - ne, - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), - (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], - (uint32_t) dst->ne[3], - dim, - (uint32_t) src0->ne[dim] - }; - - std::vector<wgpu::BindGroupEntry> entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), - }; - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; @@ -2344,8 +2317,52 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); - auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get()); + + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + } + + std::vector<uint32_t> params = { ne, + offset_src0, + offset_src1, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + dim, + (uint32_t) src0->ne[dim] }; + + std::vector<wgpu::BindGroupEntry> entries = {}; + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl index a22d245d2cc..eb901bf0547 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl @@ -31,6 +31,16 @@ struct Params { #define DataType i32 #endif +#ifdef SRC_OVERLAP +@group(0) @binding(0) +var<storage, read_write> merged_src: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(2) +var<uniform> params: Params; +#else @group(0) @binding(0) var<storage, read_write> src0: array<DataType>; @@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>; @group(0) @binding(3) var<uniform> params: Params; - +#endif @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { @@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { ni[1] * params.stride_src0_1 + ni[2] * params.stride_src0_2 + ni[3] * params.stride_src0_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i]; +#else dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; +#endif } else { ni[params.dim] -= params.src0_nedim; let src_i = ni[0] * params.stride_src1_0 + ni[1] * params.stride_src1_1 + ni[2] * params.stride_src1_2 + ni[3] * params.stride_src1_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i]; +#else dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; +#endif } } } From aa42b48312f28f31a84840d51bcc783380b00d03 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura <yoshimura.masashi.frbs@gmail.com> Date: Tue, 9 Jun 2026 07:19:56 +0900 Subject: [PATCH 263/289] ggml-webgpu: Improve prefill speeds for k-quants + refactor matmul for Q4/Q5/Q8 and k-quants (llama/24225) * ggml-webgpu: Improve prefill speeds + refactor matmul for quants * Fixes for editroconfig checker --- .../wgsl-shaders/mul_mat_decls.tmpl | 810 ++++++------------ 1 file changed, 267 insertions(+), 543 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 72991504dd0..ed4a6b13bbf 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -98,72 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_Q1_0 -#ifdef INIT_SRC0_SHMEM_Q4_0 +#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4) const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; +#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +#else const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +#endif const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; + let block_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; + let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - let tile_m = blck_idx / BLOCKS_K; + let tile_m = block_idx / BLOCKS_K; let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; + let block_k = block_idx % BLOCKS_K; let global_block_k = k_outer / BLOCK_SIZE + block_k; if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + +#ifdef INIT_SRC0_SHMEM_Q4_0 + let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_0 +#elif INIT_SRC0_SHMEM_Q4_1 + let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); -#ifdef INIT_SRC0_SHMEM_Q4_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 20u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -175,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_1 - -#ifdef INIT_SRC0_SHMEM_Q5_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 22u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -// tile_k is defined as 32u, so blocks_k ends up being 1 always -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q5_0 + let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u; let d = load_f16_at_src0(block_byte_base); let qh_packed = load_u32_at_src0(block_byte_base + 2u); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -226,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_0 - -#ifdef INIT_SRC0_SHMEM_Q5_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 24u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; +#elif INIT_SRC0_SHMEM_Q5_1 + let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u; - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); + let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); + let q_packed = load_u32_at_src0_aligned(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); @@ -277,236 +201,73 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_1 - -#ifdef INIT_SRC0_SHMEM_Q8_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 34u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q8_0 + let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q8_0 +#elif INIT_SRC0_SHMEM_Q8_1 + let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); -#ifdef INIT_SRC0_SHMEM_Q8_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 36u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d + m; shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q8_1 - -#ifdef INIT_SRC0_SHMEM_Q2_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 84u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - // Use standard thread layout instead of lane/row_group - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; +#elif INIT_SRC0_SHMEM_MXFP4 + let block_byte_base = src0_idx * 17u; + let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u); + let e = ldexp(1.0, i32(eu8) - 128); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); + } + } +#endif } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base + 80u); - let dmin = load_f16_at_src0(block_byte_base + 82u); - - // Decode the element at position k_in_block - let block_of_32 = k_in_block / 32u; - let pos_in_32 = k_in_block % 32u; - - let q_b_idx = (block_of_32 / 4u) * 32u; - let shift = (block_of_32 % 4u) * 2u; - let k = (pos_in_32 / 16u) * 16u; - let l = pos_in_32 % 16u; - - let is = k_in_block / 16u; - - let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); - let sc = get_byte(sc_packed, is % 4u); - - let dl = d * f16(sc & 0xFu); - let ml = dmin * f16(sc >> 4u); - - let q_idx = q_b_idx + k + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 3u; - - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; } } -#endif // INIT_SRC0_SHMEM_Q2_K +#endif -#ifdef INIT_SRC0_SHMEM_Q3_K +// k-quants +#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K) const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 110u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +const NQ = 4u; - let d = load_f16_at_src0(block_byte_base + 108u); - - // Load and unpack scales - let kmask1: u32 = 0x03030303u; - let kmask2: u32 = 0x0f0f0f0fu; - - var scale_vals: array<u32, 4>; - for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); - } - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u); - scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u); - - // Load hmask and qs arrays - var hmask_vals: array<u32, 8>; - for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); - } - - var qs_vals: array<u32, 16>; - for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); - } - - let half = k_in_block / 128u; // 0 or 1 - let pos_in_half = k_in_block % 128u; // 0-127 - let shift_group = pos_in_half / 32u; // 0-3 - let pos_in_32 = pos_in_half % 32u; // 0-31 - let k_group = pos_in_32 / 16u; // 0 or 1 - let l = pos_in_32 % 16u; // 0-15 - - let q_b_idx = half * 32u; // 0 or 32 - let shift = shift_group * 2u; // 0, 2, 4, 6 - let k = k_group * 16u; // 0 or 16 - let is = k_in_block / 16u; // 0-15 - - // m increments every 32 elements across entire 256 element block - let m_shift = k_in_block / 32u; // 0-7 - let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 - - let sc = get_byte(scale_vals[is / 4u], is % 4u); - let dl = d * (f16(sc) - 32.0); - - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - - let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u); - - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3u; - - let q_val = (f16(qs_val) - f16(hm)) * dl; - shmem[elem_idx] = q_val; - } +fn store_shmem_kquants(val: vec4<f16>, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; } -#endif // INIT_SRC0_SHMEM_Q3_K - -#ifdef INIT_SRC0_SHMEM_Q4_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 144u; +fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 { + return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u); +} fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; @@ -514,224 +275,232 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let global_k = k_outer + tile_k; if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); + store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx); continue; } - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - // Map k_in_block to loop structure: - // Outer loop over 64-element groups (alternating q_b_idx) - // Inner loop over 2 shifts per group - let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx) - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 +#ifdef INIT_SRC0_SHMEM_Q2_K + let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u; + let scales_byte_base = block_byte_base; + let qs_byte_base = block_byte_base + 16u; + let dm_byte_base = block_byte_base + 80u; + + let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(d_packed[0]); + let dmin = f16(d_packed[1]); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + // whole 2 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u), + ); + + let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block); + + let dl = d * f16(scale & 0xFu); + let ml = dmin * f16(scale >> 4u); + + store_shmem_kquants(qs_vec4 * dl - ml, elem_idx); +#elif INIT_SRC0_SHMEM_Q3_K + let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u; + let hmask_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 32u; + let scales_byte_base = block_byte_base + 96u; + + let d_all = load_f16_at_src0(block_byte_base + 108u); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + let hmask_block = pos_in_chunk; + let hmask_shift_phase = k_in_block / 32u; + + // low 2 bits (4 elems) + let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block); + let q_lo2_vec4 = vec4<f16>( + f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u) + ); + + // high 1 bit (4 elems) + let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk); + let q_hi1_vec4 = vec4<f16>( + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u)) + ); + + let q_vec4 = q_lo2_vec4 - q_hi1_vec4; + + let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu; + let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u; + let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0); + + store_shmem_kquants(dl * q_vec4, elem_idx); +#elif INIT_SRC0_SHMEM_Q4_K + let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qs_byte_base = block_byte_base + 16u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + // whole 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); - - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 0xFu; - - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; - } -} -#endif // INIT_SRC0_SHMEM_Q4_K - -#ifdef INIT_SRC0_SHMEM_Q5_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 176u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - - // The original loop processes elements in groups of 64 - // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] - // But u increments EVERY 32 elements (after each l loop) - let group_of_64 = k_in_block / 64u; // 0-3 - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 - - // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128) - let u_shift = k_in_block / 32u; // 0-7 - let u: u32 = 1u << u_shift; + store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q5_K + let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qh_byte_base = block_byte_base + 16u; + let qs_byte_base = block_byte_base + 48u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + let qh_block = k_in_block % 32u; + let qh_shift_phase = sub_block; + + // low 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_lo4_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); + + // high 1 bit (4 elems) + let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block); + let qh_vec4 = vec4<f16>( + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u)) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u)); - - let q_byte = get_byte(q_packed, q_idx % 4u); - - let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u)); - - let qh_byte = get_byte(qh_packed, l % 4u); - - let qs_val = (q_byte >> shift) & 0xFu; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - - let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; - shmem[elem_idx] = q_val; - } -} - -#endif // INIT_SRC0_SHMEM_Q5_K - -#ifdef INIT_SRC0_SHMEM_Q6_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let half = k_in_block / 128u; - let pos_in_half = k_in_block % 128u; - let quarter = pos_in_half / 32u; - let l = pos_in_half % 32u; - - let ql_b_idx = half * 64u; - let qh_b_idx = half * 32u; - let sc_b_idx = half * 8u; - - // Load only ql13 word needed - let ql13_flat = ql_b_idx + l; - let ql13 = load_u32_at_src0(block_byte_base + ql13_flat); - let ql13_b = get_byte(ql13, 0u); - - // Load only ql24 word needed - let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_u32_at_src0(block_byte_base + ql24_flat); - let ql24_b = get_byte(ql24, 0u); - - // Load only qh word needed - let qh_flat = qh_b_idx + l; - let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat); - let qh_b = get_byte(qh, 0u); - - let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); - let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); - let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0); - let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0); - - // Load only the scale word needed - let is = l / 16u; - let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx); - let sc_val = get_byte_i32(sc, 0u); - - let d = load_f16_at_src0(block_byte_base + 208u); - - var q_val: f16; - if (quarter == 0u) { - q_val = q1; - } else if (quarter == 1u) { - q_val = q2; - } else if (quarter == 2u) { - q_val = q3; - } else { - q_val = q4; - } - - shmem[elem_idx] = d * f16(sc_val) * q_val; + store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q6_K + let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u; + let ql_byte_base = block_byte_base; + let qh_byte_base = block_byte_base + 128u; + let scales_byte_base = block_byte_base + 192u; + let d_byte_base = block_byte_base + 208u; + + let d = load_f16_at_src0(d_byte_base); + + let chunk = k_in_block / 128u; + let ql_pos_in_chunk = (k_in_block % 128u) % 64u; + let qh_pos_in_chunk = (k_in_block % 128u) % 32u; + let sub_block = k_in_block / 16u; + let ql_shift_phase = (k_in_block % 128u) / 64u; + let qh_shift_phase = (k_in_block % 128u) / 32u; + + // low 4 bits (4 elems) + let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk); + let ql_lo4_vec4 = vec4<u32>( + (ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu + ); + + // hi 2 bits (4 elems) + let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk); + let qh_hi2_vec4 = vec4<u32>( + ((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u, + ); + + let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0); + + let scale_byte = scales_byte_base + 1u * sub_block; + let scale_word = load_u32_at_src0_aligned(scale_byte); + let scale = get_byte_i32(scale_word, scale_byte & 3u); + + store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx); +#endif } } -#endif // INIT_SRC0_SHMEM_Q6_K +#endif // k-quants #ifdef INIT_SRC0_SHMEM_IQ4_NL const BLOCK_SIZE = 32u; @@ -1155,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_IQ3_S - -#ifdef INIT_SRC0_SHMEM_MXFP4 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 17u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); - let e = ldexp(1.0, i32(eu8) - 128); - - // store NQ(16) weights - for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - - let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; - let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); - } - } - } - } -} -#endif // INIT_SRC0_SHMEM_MXFP4 From e69e5138fe7a1737d3d37ec52903bb050a09a0eb Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Mon, 8 Jun 2026 20:54:24 -0700 Subject: [PATCH 264/289] ggml-webgpu: Add clang-format job (llama/24308) * Add clang-format job * try local formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 19 +++++++++++-------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 +++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c75a98a8dd4..6f877f15ce9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -644,7 +644,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { const uint32_t offset_elems = - (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type)); + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / + ggml_type_size(K->type)); return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; } @@ -655,8 +656,10 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); } -inline bool ggml_webgpu_flash_attn_kv_direct( - const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) { +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V, + uint32_t kv_direct_align) { return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); } @@ -671,10 +674,10 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = context.src3 != nullptr; - key.has_sinks = context.src4 != nullptr; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; return key; } @@ -1727,7 +1730,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 79d5138029d..538e587bbe5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4253,9 +4253,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const const uint32_t q_tile = use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; - const bool kv_direct = use_subgroup_matrix ? - ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : - false; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); From 72894aa2503e3eb3fdb99592da20bb313f1a9c44 Mon Sep 17 00:00:00 2001 From: ravel7524 <58877666+ravel7524@users.noreply.github.com> Date: Tue, 9 Jun 2026 07:46:23 +0200 Subject: [PATCH 265/289] Remove case for GGML_TYPE_Q4_K in mvvq.cu (llama/23528) --- ggml/src/ggml-cuda/mmvq.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index bdfbfd2d387..fe44a58da91 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -411,7 +411,6 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_K: return 8; case GGML_TYPE_Q6_K: return 2; From 2d68a3066f601e95f189dd8468b3e9fe73ac445e Mon Sep 17 00:00:00 2001 From: Yash Raj Pandey <55940078+devYRPauli@users.noreply.github.com> Date: Tue, 9 Jun 2026 03:24:27 -0400 Subject: [PATCH 266/289] ggml-cpu : fix rms_norm_back wrong output under in-place aliasing (llama/24305) * ggml-cpu : fix rms_norm_back wrong output under in-place aliasing * cont : clean-up comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --- ggml/src/ggml-cpu/ops.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3a1912ae91b..becac9d6ef9 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4008,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32( // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) - ggml_vec_cpy_f32 (ne00, dx, x); - // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - ggml_vec_acc_f32 (ne00, dx, dz); - ggml_vec_scale_f32(ne00, dx, rrms); + // dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms + // note: https://github.com/ggml-org/ggml/issues/1491 + const float scale_x = (float) (-sum_xdz) / sum_eps; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms; + } } } } From 28c7ed3db7e24261742123d6b33a90b0ef681808 Mon Sep 17 00:00:00 2001 From: Pascal <admin@serveurperso.com> Date: Tue, 9 Jun 2026 11:01:37 +0200 Subject: [PATCH 267/289] ggml : add GGML_OP_COL2IM_1D (llama/24206) * cpu: add GGML_OP_COL2IM_1D Add the overlap-add (scatter-add) step of a 1D transposed convolution. A ConvTranspose1d factorizes as a GEMM followed by col2im: a weight pre-permuted to [IC, K*OC] is contracted against the [IC, T_in] input with mul_mat to produce a column matrix [K*OC, T_in], and col2im_1d scatters those columns back into the [T_out, OC] signal, with T_out = (T_in - 1)*s0 + K - 2*p0. Keeping the contraction as a plain mul_mat leaves the heavy work on the optimized (and quantizable) matmul kernels, so col2im_1d only does the cheap overlap-add. CPU uses a gather formulation parallelized over output channels, supporting F32, F16 and BF16 with an F32 accumulator. * tests: add backend coverage for GGML_OP_COL2IM_1D Add test_col2im_1d next to the conv_transpose_1d cases, covering F32, F16 and BF16 across eight geometries: the canonical kernel = 2*stride DAC upsampling shape, overlap, no overlap, cropping (p0 = 1 and p0 = stride/2), kernel < stride with zeroed gaps, kernel not a multiple of stride, and a single column unfold. Perf mode gets three real vocoder stage shapes reporting memory bandwidth. max_nmse_err relaxes to 5e-4 for F16 and BF16. * cpu: harden GGML_OP_COL2IM_1D ggml_col2im_1d validates s0, oc, p0 and input contiguity at graph build time, before the oc division, protecting every backend at once. The kernel asserts the contiguity its flat indexing assumes and its doc states the full output length including the crop term. The kernel parallelizes over the time axis: the split stays balanced down to OC = 1, where the previous channel split was single threaded. Values are bit identical on the three real vocoder chains, two out of three improve. * tests: extend the GGML_OP_COL2IM_1D grid The eval grid grows to eleven geometries: OC = 1 (mono output stage), K = 1 with stride > 1 (sparse scatter, every gap position zeroed) and a crop down to T_out = 2 where all the gather bounds act at once. * tests: add col2im_1d equivalence test tests/test-col2im-1d.cpp proves mul_mat + col2im_1d matches the native ggml_conv_transpose_1d on the CPU backend, F32 bit exact, F16 and BF16 through casts of the column matrix. test-backend-ops cannot cover this for a CPU only op since the CPU backend is its own reference there. * rpc: bump protocol patch version for GGML_OP_COL2IM_1D GGML_OP_COUNT goes from 96 to 97 with the new op, which trips the static_assert in ggml-rpc.h. Bump RPC_PROTO_PATCH_VERSION since the op is appended and no existing op code shifts. --- ggml/include/ggml-rpc.h | 4 +- ggml/include/ggml.h | 11 ++++++ ggml/src/ggml-cpu/ggml-cpu.c | 5 +++ ggml/src/ggml-cpu/ops.cpp | 72 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml.c | 41 +++++++++++++++++++- 6 files changed, 130 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 6fcf5a43393..5ad121ae57f 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,10 +8,10 @@ extern "C" { #define RPC_PROTO_MAJOR_VERSION 4 #define RPC_PROTO_MINOR_VERSION 0 -#define RPC_PROTO_PATCH_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 1 #ifdef __cplusplus -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); #endif #define GGML_RPC_MAX_SERVERS 16 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f6725265504..374934aacf3 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -535,6 +535,7 @@ extern "C" { GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, GGML_OP_IM2COL_3D, + GGML_OP_COL2IM_1D, GGML_OP_CONV_2D, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, @@ -2007,6 +2008,16 @@ extern "C" { int d1, // dilation dimension 1 bool is_2D); + // col2im_1d: scatter-add GEMM columns back to 1D signal + // a: [K*OC, T_in] (columns from matmul, K = a->ne[0]/OC) + // result: [T_out, OC] where T_out = (T_in - 1)*s0 + K - 2*p0 + GGML_API struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, // columns [K*OC, T_in] + int s0, // stride + int oc, // output channels + int p0); // padding to crop from both sides + GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index cd5c61a8187..af7827aec39 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1912,6 +1912,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_3d(params, tensor); } break; + case GGML_OP_COL2IM_1D: + { + ggml_compute_forward_col2im_1d(params, tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor); @@ -2343,6 +2347,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONV_2D: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: + case GGML_OP_COL2IM_1D: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index becac9d6ef9..86842e55474 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6730,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; // adding size avoids negative number weirdness } +// ggml_compute_forward_col2im_1d +// +// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC] +// where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs. +// Parallelized over the time axis so the split stays balanced whatever OC is. +// Supports F32, F16, BF16 input/output (same type), F32 accumulator. + +template <typename elem_t> +static void ggml_compute_forward_col2im_1d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src = dst->src[0]; // [K*OC, T_in] + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t OC = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + + const int64_t K_OC = src->ne[0]; + const int64_t T_in = src->ne[1]; + const int64_t K = K_OC / OC; + const int64_t T_out = dst->ne[0]; + + const elem_t * col_data = (const elem_t *) src->data; + elem_t * dst_data = (elem_t *) dst->data; + + const int ith = params->ith; + const int nth = params->nth; + + // Parallelize over the time axis: the split stays balanced whatever OC is, + // down to OC = 1 for mono audio, and threads read disjoint column bands + const int64_t dr = (T_out + nth - 1) / nth; + const int64_t it0 = dr * ith; + const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out; + + for (int64_t oc = 0; oc < OC; oc++) { + for (int64_t t_out = it0; t_out < it1; t_out++) { + const int64_t t_abs = t_out + p0; // absolute position in uncropped signal + // Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K + int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s) + if (t_in_min < 0) t_in_min = 0; + int64_t t_in_max = t_abs / s0; + if (t_in_max >= T_in) t_in_max = T_in - 1; + + float sum = 0.0f; + for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) { + int64_t k = t_abs - t_in * s0; + if (k >= 0 && k < K) { + // col layout: [K*OC, T_in], element (oc*K+k, t_in) + sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]); + } + } + // dst layout: [T_out, OC], element (t_out, oc) + dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum); + } + } +} + +void ggml_compute_forward_col2im_1d( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break; + case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break; + case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break; + default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type); + } +} + // ggml_compute_forward_conv_2d diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 7398e561894..a8e18c716db 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -68,6 +68,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_col2im_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8815c67d8bc..18a5ebd2ab0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "IM2COL", "IM2COL_BACK", "IM2COL_3D", + "COL2IM_1D", "CONV_2D", "CONV_3D", "CONV_2D_DW", @@ -1080,7 +1081,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1141,6 +1142,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "im2col(x)", "im2col_back(x)", "im2col_3d(x)", + "col2im_1d(x)", "conv_2d(x)", "conv_3d(x)", "conv_2d_dw(x)", @@ -1190,7 +1192,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4541,6 +4543,41 @@ struct ggml_tensor * ggml_conv_1d_dw_ph( return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0); } +// ggml_col2im_1d + +struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int s0, + int oc, + int p0) { + GGML_ASSERT(ggml_is_matrix(a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16); + GGML_ASSERT(s0 > 0); + GGML_ASSERT(oc > 0); + GGML_ASSERT(p0 >= 0); + + const int64_t K_OC = a->ne[0]; + const int64_t T_in = a->ne[1]; + const int64_t K = K_OC / oc; + const int64_t T_out = (T_in - 1) * s0 + K - 2 * p0; + + GGML_ASSERT(K_OC == K * oc); // a->ne[0] must be a whole number of oc blocks + GGML_ASSERT(K > 0 && T_out > 0); + + const int64_t ne[4] = { T_out, oc, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 2, ne); + + int32_t params[] = { s0, (int32_t)oc, (int32_t)p0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_COL2IM_1D; + result->src[0] = a; + + return result; +} + // ggml_conv_transpose_1d static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { From 686bc802d1f5df3cf4a23102eca785b62877619b Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Tue, 9 Jun 2026 13:27:04 +0200 Subject: [PATCH 268/289] vulkan: add `v_dot2_f32_f16` support in matrix-matrix multiplication and Flash Attention (llama/24123) * vulkan: add support for valve fp16 dot2 extension * use macro for dot2 path choice * properly check for the feature * add dot_product abstraction to reduce preprocessor branching --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 84 ++++++++++++++++--- .../vulkan-shaders/dot_product_funcs.glsl | 27 ++++++ .../vulkan-shaders/flash_attn.comp | 5 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 12 +-- .../vulkan-shaders/vulkan-shaders-gen.cpp | 46 ++++++---- 5 files changed, 139 insertions(+), 35 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2dd8cd2fbd9..c4ea0b105ce 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -113,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { } VkPhysicalDeviceShaderBfloat16FeaturesKHR; #endif +#if !defined(VK_VALVE_shader_mixed_float_dot_product) +#define VK_VALVE_shader_mixed_float_dot_product 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000) +typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE { + VkStructureType sType; + void* pNext; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat32; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat16; + VkBool32 shaderMixedFloatDotProductBFloat16Acc; + VkBool32 shaderMixedFloatDotProductFloat8AccFloat32; +} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE; +#endif + #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } @@ -705,6 +720,8 @@ struct vk_device_struct { bool coopmat2_bf16_support {}; bool coopmat2_decode_vector; + bool dot2_f16 {}; + bool pipeline_executable_properties_support {}; size_t idx; @@ -3920,8 +3937,13 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } else { if (device->fp16) { - if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } - else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + if (device->dot2_f16) { + if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; } + else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } } else { spv_data = flash_attn_f32_f16_fp32_data; spv_size = flash_attn_f32_f16_fp32_len; @@ -4215,7 +4237,23 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} + // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + + // bf16 scalar path promotes to f32, no dot2 variant +#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ @@ -4250,7 +4288,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4258,7 +4296,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4298,8 +4335,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -4344,8 +4380,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4390,6 +4425,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM +#undef CREATE_MM_NODOT2 } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ @@ -5453,6 +5489,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = false; device->shader_64b_indexing = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -5495,6 +5532,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { pipeline_executable_properties_support = true; } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && @@ -5802,6 +5842,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product"); + } + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; if (pipeline_executable_properties_support) { @@ -5836,6 +5884,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->bf16 = false; #endif + device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && @@ -6250,6 +6300,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool coopmat2_decode_vector_support = false; bool integer_dot_product = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -6279,6 +6330,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } } @@ -6369,6 +6423,13 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; } + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + } + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; @@ -6415,9 +6476,12 @@ static void ggml_vk_print_gpu_info(size_t idx) { : coopmat_support ? "KHR_coopmat" : "none"; + bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0"; + std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size, props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl new file mode 100644 index 00000000000..c474bfe09ce --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl @@ -0,0 +1,27 @@ +#ifdef DOT2_F16 +#extension GL_EXT_spirv_intrinsics : require + +spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"], + capabilities = [6912], id = 6916) +float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc); + +ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc)))); +} + +ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc))); +} + +#else + +ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), + fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc)))); +} + +ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc)); +} + +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6ac095489b3..91fb07c93e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -21,6 +21,7 @@ #extension GL_KHR_shader_subgroup_vote : enable #include "types.glsl" +#include "dot_product_funcs.glsl" #include "flash_attn_base.glsl" #include "flash_attn_dequant.glsl" @@ -318,7 +319,7 @@ void main() { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); + Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]); } } } @@ -341,7 +342,7 @@ void main() { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); + Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 89346e48e06..f39410d74f0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -29,6 +29,7 @@ #endif #include "types.glsl" +#include "dot_product_funcs.glsl" #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 @@ -329,15 +330,8 @@ void main() { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; - #if defined(DATA_A_F32) || defined(DATA_A_F16) - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x)))); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y)))); - #else - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); - #endif + sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x); + sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8fc00362870..7bcb1460814 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -336,7 +336,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734 // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860 - if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) { + // disable spirv-opt for dot2 shaders (spirv-opt doesn't recognize SPV_VALVE_mixed_float_dot_product capability) + if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos && name.find("_dot2") == std::string::npos) { cmd.push_back("-O"); } @@ -427,10 +428,11 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s generate_dep_file = false; } -void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc, bool dot2 = false) { std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + std::string dot2_sfx = dot2 ? "_dot2" : ""; std::map<std::string, std::string> base_dict; std::string shader_name = "matmul"; @@ -463,6 +465,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } #endif + if (dot2) { + base_dict["DOT2_F16"] = "1"; + } + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string { @@ -528,11 +534,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -553,8 +559,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + if (!dot2) { + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } } } @@ -584,18 +592,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - // Integer dot mmq performs better with f32 accumulators - if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { + // Integer dot mmq performs better with f32 accumulators (different shader, skip for dot2) + if (!f16acc && !coopmat && !coopmat2 && !dot2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif @@ -613,6 +621,10 @@ void process_shaders() { matmul_shaders(true, matmul_id_type, false, false, false); matmul_shaders(true, matmul_id_type, false, false, true); + // dot2 variants (scalar fp16 only) + matmul_shaders(true, matmul_id_type, false, false, false, true); + matmul_shaders(true, matmul_id_type, false, false, true, true); + if (matmul_id_type != MatMulIdType::DEFAULT) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) // Coopmat, fp32acc and fp16acc @@ -660,6 +672,12 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16", "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); + + if (fp16) { + string_to_spv("flash_attn_f32_f16_dot2", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DOT2_F16", "1"}}), fp16, false, false, f16acc); + } + #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) string_to_spv("flash_attn_f32_f16", "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); From dc794303d86cfc650f41e2545ae8cf19a7dc5548 Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Tue, 9 Jun 2026 06:27:38 -0500 Subject: [PATCH 269/289] vulkan: reduce iq1 shared memory usage for mul_mm (llama/24287) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +++- ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp | 1 + ggml/src/ggml-vulkan/vulkan-shaders/types.glsl | 8 +++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c4ea0b105ce..22405f234de 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3394,7 +3394,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec switch (src0_type) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: - lut_size = 2*2048 + 4*2048; + // Regular matmul uses the compact uint16_t IQ1 grid; the expanded + // uint32_t grid is only enabled for the q8_1/int-dot vector path. + lut_size = 2*2048; break; case GGML_TYPE_IQ2_XXS: lut_size = 8*256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 6fe3e2dc043..fd84c3c91d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -4,6 +4,7 @@ #extension GL_EXT_integer_dot_product : require #define MMQ +#define NEEDS_IQ1S_GRID_GPU #define B_TYPE block_q8_1_x4 #include "mul_mat_vec_base.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index f84d6f87334..8c6b20c6889 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -598,9 +598,10 @@ const uint[1024] iq1s_grid_const = { 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 }; +#if defined(NEEDS_IQ1S_GRID_GPU) // Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit // and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F -// and 0xF0F0F0F0). +// and 0xF0F0F0F0). This is only used by the q8_1/int-dot vector path. const uint32_t[2048] iq1s_grid_gpu_const = { 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, @@ -859,9 +860,12 @@ const uint32_t[2048] iq1s_grid_gpu_const = { 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; +#endif shared uint16_t iq1s_grid[2048]; +#if defined(NEEDS_IQ1S_GRID_GPU) shared uint32_t iq1s_grid_gpu[2048]; +#endif #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) @@ -875,12 +879,14 @@ void init_iq_shmem(uvec3 wgsize) iq1s_grid[2*idx+1] = g.y; } } +#if defined(NEEDS_IQ1S_GRID_GPU) [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) { uint idx = i + gl_LocalInvocationIndex.x; if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) { iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx]; } } +#endif barrier(); } #endif From ef85b26d9f0bfeba3548ea6ceb213a5191ef4c11 Mon Sep 17 00:00:00 2001 From: Oliver Simons <osimons@nvidia.com> Date: Wed, 10 Jun 2026 14:27:08 +0200 Subject: [PATCH 270/289] CUDA: Fix ssm_scan_f32 data-races (llama/24360) * Add missing syncthreads before resuing cub_temp_storage __syncthreads() is required before being allowed to resue TempStorage smem: https://nvidia.github.io/cccl/unstable/cub/api/classcub_1_1BlockLoad.html#_CPPv4I0EN3cub9BlockLoad4LoadEv20RandomAccessIteratorRA14ItemsPerThread_1Ti * Add one more missing __syncthreads Could also double-buffer, but alternative is to simply ensure all threads have read smem* before writing to it again in the next loop iteration * Remove unused smem from ssm_scan_f32 --- ggml/src/ggml-cuda/ssm-scan.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 2e3f97c7284..3022249c77d 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -67,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1) __shared__ CubTempStorage cub_temp_storage; BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); + __syncthreads(); BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); #else const int stride_s0 = src0_nb2 / sizeof(float); @@ -105,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1) regs0[n] = state; } y_block[i * stride_y + threadIdx.x] = sumf; + __syncthreads(); } #ifdef USE_CUB @@ -249,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); - const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); switch (n_tok) { case 1: From 1a1900f90c165a078d66b4958db46fe82d14ff27 Mon Sep 17 00:00:00 2001 From: Gaurav Garg <gaugarg@nvidia.com> Date: Wed, 10 Jun 2026 23:21:16 +0530 Subject: [PATCH 271/289] Remove padding and multiple D2D copies for MTP (llama/24086) * Make ggml_gated_delta_net take only the initial recurrent state (D, 1, n_seqs) and passes the snapshot count K as an op parameter instead of inferring it from state->ne[1]. Remove the padding hack and copy all emitted snapshots into the recurrent cache with a single strided ggml_cpy * Make GDN changes in all backends. Address review comments. * Fix CI build errors --- ggml/include/ggml.h | 17 +++++++---- ggml/src/ggml-backend-meta.cpp | 4 +-- ggml/src/ggml-cpu/ggml-cpu.c | 2 +- ggml/src/ggml-cpu/ops.cpp | 17 +++++------ ggml/src/ggml-cuda/gated_delta_net.cu | 16 +++++----- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 ++-- .../ggml-hexagon/htp/gated-delta-net-ops.c | 29 ++++++++++--------- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 +-- ggml/src/ggml-metal/ggml-metal.metal | 11 ++++--- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- .../ggml-opencl/kernels/gated_delta_net.cl | 8 +++-- ggml/src/ggml-sycl/gated_delta_net.cpp | 15 +++++----- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++--- .../vulkan-shaders/gated_delta_net.comp | 11 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../wgsl-shaders/gated_delta_net.wgsl | 7 +++-- ggml/src/ggml.c | 16 ++++++---- 17 files changed, 93 insertions(+), 81 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 374934aacf3..d6807b6dd47 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2553,10 +2553,16 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 // - // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): - // K == 1: output carries the final state only. - // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) - // per-token snapshots into the trailing slots + // tensor shapes (S_k == S_v, H_v % H_k == 0): + // q, k : [S_k, H_k, n_tokens, n_seqs] + // v : [S_v, H_v, n_tokens, n_seqs] + // g : [1, H_v, n_tokens, n_seqs] (scalar gate) or [S_v, H_v, n_tokens, n_seqs] (KDA) + // beta : [1, H_v, n_tokens, n_seqs] + // state : [S_v, S_v, H_v, n_seqs] -- initial recurrent state s0 + // + // the output packs the attention scores [S_v, H_v, n_tokens, n_seqs] followed by K state + // snapshots, most-recent first (slot 0 = final state, slot s = state s tokens back). K == 1 + // keeps only the final state; when n_tokens < K only slots 0..n_tokens-1 are written. GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, @@ -2564,7 +2570,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state); + struct ggml_tensor * state, + int64_t K); // custom operators diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 8c44c3e44ae..0a36f099000 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -776,8 +776,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, - // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2, + // so a head-aligned split on the input cache lands on axis 2 here. GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index af7827aec39..eb8341c9aec 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2948,7 +2948,7 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t K = ggml_get_op_params_i32(node, 0); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); cur = per_thread * sizeof(float) * n_tasks; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 86842e55474..74611dce7f1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10624,11 +10624,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int64_t K = src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int64_t K = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(K >= 1); - // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) - const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + // per-seq stride in floats (seq s starts at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[3] / sizeof(float); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; @@ -10644,9 +10644,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int64_t shift = n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. const float * state_in_base = (const float *)src_state->data; @@ -10674,7 +10673,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( : state_out_base + (iv3 * H + iv1) * S_v * S_v; // copy input state into the working buffer and operate in-place - // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride. const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); @@ -10727,7 +10726,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token if (K > 1) { - const int64_t target_slot = t - shift; + const int64_t target_slot = n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state_o = state_out_base + target_slot * state_size_per_snap + (iv3 * H + iv1) * S_v * S_v; diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 7cfda652367..a547360eb06 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -39,9 +39,9 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; state += state_out_offset; curr_state += state_in_offset + col * S_v; @@ -143,12 +143,10 @@ gated_delta_net_cuda(const float * q, attn_data += S_v * H; if constexpr (keep_rs_t) { - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; - + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output - const int target_slot = t - shift; + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -286,8 +284,8 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d550841a2a5..49bd7e4331a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2538,7 +2538,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses const int64_t H = v->ne[1]; const int64_t n_tokens = v->ne[2]; const int64_t n_seqs = v->ne[3]; - const int64_t K = state->ne[1]; + const int64_t K = ggml_get_op_params_i32(op, 0); if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { return false; @@ -2551,7 +2551,8 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } - if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) { + // state holds s0 only [S_v, S_v, H, n_seqs]; K is op param 0. + if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { return false; } if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index 3b092d7440d..35518e6111c 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -584,7 +584,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; + const uint32_t K = octx->op_params[0]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -618,9 +618,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); - const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; - const int64_t shift = (int64_t) n_tokens - (int64_t) K; uint32_t ir_prefetch = ith; int spad_idx = 0; @@ -630,7 +629,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; - float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; // Push dummy write-back dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), @@ -661,7 +661,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = fastdiv(iv3, &fd_rq3); const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -792,7 +793,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo } if (K > 1) { - const int64_t target_slot = (int64_t) t - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t; if (target_slot >= 0 && target_slot < (int64_t) K) { float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; if (curr_state_o != s_out) { @@ -844,7 +846,6 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t S_v = v->ne[0]; const uint32_t H = v->ne[1]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -878,8 +879,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); - const uint64_t state_seq_stride = state->nb[2] / sizeof(float); - const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); uint32_t ir_prefetch = ith; int spad_idx = 0; @@ -889,7 +889,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; - float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; // Push dummy write-back dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), @@ -920,7 +921,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = fastdiv(iv3, &fd_rq3); const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -1097,7 +1099,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) { const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; + const uint32_t K = octx->op_params[0]; if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { return HTP_STATUS_NO_SUPPORT; @@ -1110,7 +1112,8 @@ int op_gated_delta_net(struct htp_ops_context * octx) { (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return HTP_STATUS_NO_SUPPORT; } - if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + // state holds s0 only: [S_v, S_v, H, n_seqs] + if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) { return HTP_STATUS_NO_SUPPORT; } if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ce847dd8b6f..4f4f073cb61 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,8 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G - // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = op->src[5]->ne[1]; + // state is src[5], 4D [S_v, S_v, H_v, n_seqs] (s0 only); K is op param 0. + const int K = ggml_get_op_params_i32(op, 0); const int nsg = op->src[2]->ne[0]/32; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2bd310d9450..0aea68455fb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2599,9 +2599,9 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2620,9 +2620,8 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = (int)args.ne22 - (int)K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. // output state base offset: after attention scores const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; @@ -2680,7 +2679,7 @@ kernel void kernel_gated_delta_net_impl( g_ptr += args.ne21*G; if (K > 1) { - const int target_slot = (int)t - shift; + const int target_slot = (int)args.ne22 - 1 - (int)t; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2a41215fd13..d30579b9452 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -17750,7 +17750,7 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) { const cl_uint H_v = (cl_uint) src_v->ne[1]; const cl_uint n_tokens = (cl_uint) src_v->ne[2]; const cl_uint n_seqs = (cl_uint) src_v->ne[3]; - const cl_uint K = (cl_uint) src_state->ne[1]; + const cl_uint K = (cl_uint) ggml_get_op_params_i32(dst, 0); int si; switch (S_v) { diff --git a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl index d11192f5802..319c9829529 100644 --- a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +++ b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl @@ -123,7 +123,8 @@ kernel void kernel_gated_delta_net( const uint iq3 = seq_id / rq3; // seq index for Q and K const uint state_size = S_V * S_V; - const uint state_base = (seq_id * K * H_v + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_base = (seq_id * H_v + head_id) * state_size; const uint q_off_base = iq3 * sq3 + iq1 * sq1; const uint v_off_base = seq_id * sv3 + head_id * sv1; const uint gb_off_base = seq_id * sb3 + head_id * sb1; @@ -143,7 +144,8 @@ kernel void kernel_gated_delta_net( } } - const int shift = (int)n_tokens - (int)K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -219,7 +221,7 @@ kernel void kernel_gated_delta_net( attn_off += S_V * H_v; if (K > 1u) { - const int target_slot = (int)t - shift; + const int target_slot = (int)n_tokens - 1 - (int)t; if (target_slot >= 0 && target_slot < (int)K) { #pragma unroll for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 9c2449aba0c..239e00bd7e5 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -44,9 +44,9 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output state += state_out_offset; @@ -63,9 +63,8 @@ void gated_delta_net_sycl(const float * q, s_shard[r] = curr_state[i]; } - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -144,7 +143,7 @@ void gated_delta_net_sycl(const float * q, // Write state back to global memory if constexpr (keep_rs_t) { - const int target_slot = t - shift; + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -315,8 +314,8 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 22405f234de..387826b6d93 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -11528,7 +11528,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; - const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -11537,8 +11536,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const uint32_t K = (uint32_t)src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0); const uint32_t s_off = S_v * H * n_tokens * n_seqs; @@ -17954,7 +17953,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[4], src_clone[5], src_clone[6]); } else if (tensor->op == GGML_OP_GATED_DELTA_NET) { tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1], - src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + src_clone[2], src_clone[3], src_clone[4], src_clone[5], + ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 33c3202dbb7..0e384330b9b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -102,8 +102,8 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. - const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_in_base = (seq_id * H + head_id) * state_size; // output state layout per slot: same per-(seq,head) offset as the single-slot case. const uint state_out_base = (seq_id * H + head_id) * state_size; const uint state_size_per_snap = state_size * H * n_seqs; @@ -113,9 +113,8 @@ void main() { s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = int(n_tokens) - int(K); + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -172,7 +171,7 @@ void main() { attn_off += S_V * H; if (K > 1u) { - const int target_slot = int(t) - shift; + const int target_slot = int(n_tokens) - 1 - int(t); if (target_slot >= 0 && target_slot < int(K)) { const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 538e587bbe5..0b605fa86ba 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1245,7 +1245,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; - const uint32_t K = (uint32_t) src5->ne[1]; + const uint32_t K = (uint32_t) ggml_get_op_params_i32(dst, 0); const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index d68520f8282..7d7b3475549 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -63,10 +63,10 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + let state_in_base = (seq_id * params.h + head_id) * state_size; let state_out_base = (seq_id * params.h + head_id) * state_size; let state_size_per_snap = state_size * params.h * params.n_seqs; - let shift = i32(params.n_tokens) - i32(params.K); var state: array<f32, S_V>; for (var i = 0u; i < S_V; i++) { @@ -128,7 +128,8 @@ fn main( attn_off += S_V * params.h; if (params.K > 1u) { - let target_slot = i32(t) - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + let target_slot = i32(params.n_tokens) - 1 - i32(t); if (target_slot >= 0 && target_slot < i32(params.K)) { let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; for (var i = 0u; i < S_V; i++) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 18a5ebd2ab0..b43016c87d2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6223,7 +6223,8 @@ struct ggml_tensor * ggml_gated_delta_net( struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state) { + struct ggml_tensor * state, + int64_t K) { GGML_ASSERT(ggml_is_contiguous_rows(q)); GGML_ASSERT(ggml_is_contiguous_rows(k)); GGML_ASSERT(ggml_is_contiguous_rows(v)); @@ -6247,15 +6248,18 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. - GGML_ASSERT(state->ne[0] == S_v * S_v * H); - GGML_ASSERT(state->ne[2] == n_seqs); - GGML_ASSERT(state->ne[3] == 1); - const int64_t K = state->ne[1]; + // state holds the initial state s0 only: [S_v, S_v, H, n_seqs]. K (snapshot slot count) is an op param. + GGML_ASSERT(state->ne[0] == S_v); + GGML_ASSERT(state->ne[1] == S_v); + GGML_ASSERT(state->ne[2] == H); + GGML_ASSERT(state->ne[3] == n_seqs); + GGML_ASSERT(K >= 1); const int64_t state_rows = K * S_v * n_seqs; const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + ggml_set_op_params_i32(result, 0, (int32_t) K); + result->op = GGML_OP_GATED_DELTA_NET; result->src[0] = q; result->src[1] = k; From a512e4c5c3adf375239e83410dffb31bff8b2a7f Mon Sep 17 00:00:00 2001 From: Kevin Liu <4396kevinliu@gmail.com> Date: Thu, 11 Jun 2026 09:43:04 -0400 Subject: [PATCH 272/289] vulkan: use medium matmul tile on Asahi Linux (llama/24306) * vulkan: use medium matmul tile on Asahi Linux * vulkan: switch Apple detection to Honeykrisp driver id --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 387826b6d93..47533c2ba97 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6202,6 +6202,17 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } + // Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE. + // Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case. + if (device->driver_id == vk::DriverId::eMesaHoneykrisp) { + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + } + device->mul_mat_l_int[i] = device->mul_mat_l[i]; device->mul_mat_m_int[i] = device->mul_mat_m[i]; device->mul_mat_s_int[i] = device->mul_mat_s[i]; From 6870cfd616bd0734c3b9ebe4dbf8010e34fdeb7e Mon Sep 17 00:00:00 2001 From: Winston Ma <winstonma@ymail.com> Date: Thu, 11 Jun 2026 21:46:25 +0800 Subject: [PATCH 273/289] vulkan: add fast path for contiguous buffer transfers (llama/23973) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 47533c2ba97..5f372404521 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7615,8 +7615,12 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); + if (width == spitch && width == dpitch) { + memcpy((uint8_t *)dst->ptr + offset, src, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); + } } } else { std::lock_guard<std::recursive_mutex> guard(dst->device->mutex); @@ -7735,8 +7739,12 @@ static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, si if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + if (width == spitch && width == dpitch) { + memcpy(dst, (const uint8_t *) src->ptr + offset, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + } } } else { std::lock_guard<std::recursive_mutex> guard(src->device->mutex); From b04008fcec0ac334d38ec754809bd7c2f8cc1f3d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Thu, 11 Jun 2026 19:32:38 +0300 Subject: [PATCH 274/289] ggml : bump version to 0.15.0 (ggml/1539) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8f7cb8cdfd2..cd0e4fef978 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 14) +set(GGML_VERSION_MINOR 15) set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") From afd559279c1a6fd484632b86ee6eee5d70ce04a2 Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Thu, 11 Jun 2026 13:22:17 -0500 Subject: [PATCH 275/289] vulkan: ifdef eMesaHoneykrisp (build fix) (llama/24479) Fixes build/CI after #24306. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5f372404521..1b1150e7731 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6202,6 +6202,7 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } +#if VK_HEADER_VERSION >= 287 // Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE. // Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case. if (device->driver_id == vk::DriverId::eMesaHoneykrisp) { @@ -6212,6 +6213,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = false; } +#endif device->mul_mat_l_int[i] = device->mul_mat_l[i]; device->mul_mat_m_int[i] = device->mul_mat_m[i]; From 2dcfd49d59810ab7fa0672e17b96968721ed6a27 Mon Sep 17 00:00:00 2001 From: shaofeiqi <shaoqi@qti.qualcomm.com> Date: Thu, 11 Jun 2026 21:43:09 -0700 Subject: [PATCH 276/289] opencl: add q5_0/q5_1 gemm and gemv kernels for Adreno (llama/24319) * opencl: add q5_0 adreno support * opencl: add q5_1 adreno support * opencl: cosmetic fix --------- Co-authored-by: Li He <lih@qti.qualcomm.com> --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 729 ++++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 114 +++ .../kernels/gemm_noshuffle_q5_0_f32.cl | 131 ++++ .../kernels/gemm_noshuffle_q5_1_f32.cl | 134 ++++ .../kernels/gemv_noshuffle_q5_0_f32.cl | 291 +++++++ .../kernels/gemv_noshuffle_q5_1_f32.cl | 294 +++++++ 7 files changed, 1642 insertions(+), 55 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index cd15d573238..82ce61d72c6 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -142,6 +142,10 @@ set(GGML_OPENCL_KERNELS gemm_noshuffle_q4_0_f32 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 + gemv_noshuffle_q5_0_f32 + gemm_noshuffle_q5_0_f32 + gemv_noshuffle_q5_1_f32 + gemm_noshuffle_q5_1_f32 gemv_noshuffle_iq4_nl_f32 gemm_noshuffle_iq4_nl_f32 gemv_noshuffle_q8_0_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d30579b9452..ca2002424df 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -593,6 +593,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q5_0_noshuffle; + cl_kernel kernel_restore_block_q5_0_noshuffle; + cl_kernel kernel_convert_block_q5_1_noshuffle; + cl_kernel kernel_restore_block_q5_1_noshuffle; cl_kernel kernel_convert_block_q4_K_noshuffle; cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; @@ -829,6 +833,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q6_K_f32; cl_kernel kernel_gemv_noshuffle_q5_k_f32; cl_kernel kernel_gemm_noshuffle_q5_k_f32; + cl_kernel kernel_gemv_noshuffle_q5_0_f32; + cl_kernel kernel_gemm_noshuffle_q5_0_f32; + cl_kernel kernel_gemv_noshuffle_q5_1_f32; + cl_kernel kernel_gemm_noshuffle_q5_1_f32; cl_kernel kernel_gemv_noshuffle_iq4_nl_f32; cl_kernel kernel_gemm_noshuffle_iq4_nl_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -1152,6 +1160,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1", &err), err)); @@ -3065,6 +3077,80 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // gemm_noshuffle_q5_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_0_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_0_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_0_f32.cl"); +#endif + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q5_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_1_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_1_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_1_f32.cl"); +#endif + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemm_noshuffle_iq4_nl_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -6107,15 +6193,16 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; - cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0_noshuffle; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); - size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {64, 1, 1}; cl_event evt; @@ -6124,7 +6211,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; } if (tensor->type == GGML_TYPE_Q5_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -6225,6 +6344,42 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + // Transpose m as ushort + transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1; cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); @@ -7299,6 +7454,48 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_int err; @@ -7362,6 +7559,54 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_trans_m; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_m.allocate(backend_ctx->context, size_m); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // Transpose back: from col-major to row-major + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_m.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -12205,7 +12450,7 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t #endif } -static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q5_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12218,17 +12463,17 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - const int ne1 = dst->ne[1]; + const int ne1 = dst->ne[1]; - GGML_ASSERT(ne00 % 32 == 0); + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); cl_context context = backend_ctx->context; cl_kernel kernel; @@ -12243,17 +12488,17 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml int K = ne00; if (ne1 == 1) { - cl_mem q_img = nullptr; + cl_mem qs_img = nullptr; cl_mem b_sub_buf = nullptr; cl_mem b_img = nullptr; - // image for q - img_fmt = { CL_R, CL_UNSIGNED_INT32}; + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; memset(&img_desc, 0, sizeof(img_desc)); img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; img_desc.image_width = M * K / 2 / 4; - img_desc.buffer = extra0_iq4_nl->q; - CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + img_desc.buffer = extra0_q5_0->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); // subbuffer for activations region.origin = offset1; @@ -12268,22 +12513,23 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml img_desc.buffer = b_sub_buf; CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; + kernel = backend_ctx->kernel_gemv_noshuffle_q5_0_f32; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); size_t local_work_size[3] = {64, 4, 1}; size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(qs_img)); CL_CHECK(clReleaseMemObject(b_sub_buf)); CL_CHECK(clReleaseMemObject(b_img)); } else { @@ -12291,6 +12537,7 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml cl_mem b_sub_buf_trans = nullptr; cl_mem b_img = nullptr; cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; // subbuffer for activations region.origin = offset1; @@ -12326,6 +12573,11 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml img_desc.buffer = b_sub_buf_trans; CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + // transpose activations int height_B = N/4; if (height_B == 0) { @@ -12346,14 +12598,14 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); // gemm - kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + kernel = backend_ctx->kernel_gemm_noshuffle_q5_0_f32; int padded_N = N + padding; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &d_sub_buf)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); @@ -12368,6 +12620,7 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); CL_CHECK(clReleaseMemObject(b_img)); CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); } #else GGML_UNUSED(backend); @@ -12377,7 +12630,7 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml #endif } -static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q5_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12386,34 +12639,21 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - GGML_ASSERT(src1->view_offs == 0); - GGML_ASSERT(dst->view_offs == 0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - - const int ne10 = src1->ne[0]; - const int ne12 = src1->ne[2]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; + const int ne1 = dst->ne[1]; - GGML_ASSERT(ne00 == ne10); - GGML_ASSERT((ne00 % 32) == 0); - GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); cl_context context = backend_ctx->context; cl_kernel kernel; @@ -12428,17 +12668,384 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t int K = ne00; if (ne1 == 1) { - cl_mem q_img = nullptr; + cl_mem qs_img = nullptr; cl_mem b_sub_buf = nullptr; cl_mem b_img = nullptr; - // image for q - img_fmt = { CL_R, CL_UNSIGNED_INT32}; + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; memset(&img_desc, 0, sizeof(img_desc)); img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc.image_width = M * K / 4; - img_desc.buffer = extra0_q8_0->q; - CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_1->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(qs_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_1_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % 32 == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_iq4_nl->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_ASSERT(src1->view_offs == 0); + GGML_ASSERT(dst->view_offs == 0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT((ne00 % 32) == 0); + GGML_ASSERT(ne0 == ne01); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 4; + img_desc.buffer = extra0_q8_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); // create a sub_buffer for B region.origin = offset1; @@ -13243,6 +13850,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q5_0 x fp32 + if (src0t == GGML_TYPE_Q5_0 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_0_f32_adreno(backend, src0, src1, dst); + return; + } + + // q5_1 x fp32 + if (src0t == GGML_TYPE_Q5_1 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_1_f32_adreno(backend, src0, src1, dst); + return; + } + // iq4_nl x fp32 if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index d07f0a1a025..226b127ab3b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -584,6 +584,60 @@ kernel void kernel_restore_block_q5_0( } } +kernel void kernel_convert_block_q5_0_noshuffle( + global struct block_q5_0 * src0, + global uchar * dst_q, + global uint * dst_qh, + global half * dst_d +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q5_0_noshuffle( + global uchar * src_q, + global uint * src_qh, + global half * src_d, + global struct block_q5_0 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_0/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + kernel void kernel_convert_block_q5_0_trans4_ns( __global struct block_q5_0 * src0, __global uint * dst_qs, @@ -736,6 +790,66 @@ kernel void kernel_restore_block_q5_1( } } +kernel void kernel_convert_block_q5_1_noshuffle( + global struct block_q5_1 * src0, + global uchar * dst_q, + global uint * dst_qh, + global half * dst_d, + global half * dst_m +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q5_1_noshuffle( + global uchar * src_q, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_1/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + kernel void kernel_convert_block_q5_1_trans4_ns( __global struct block_q5_1 * src0, __global uint * dst_qs, diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..1d6bd48005e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl @@ -0,0 +1,131 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_0_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..94b4ef6cacc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl @@ -0,0 +1,134 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_1_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + global const half * src0_m, // A mins + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + global const half * min_ptr = src0_m + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + half4 minv = vload4(0, min_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..c228f717a94 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl @@ -0,0 +1,291 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_0 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_0_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_0); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..daf1308ea4b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl @@ -0,0 +1,294 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_1 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_1_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + global half2 * src0_m, // A mins + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + __private uint4 regA; + __private half2 regS; + __private half2 regM; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_1); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + regM = src0_m[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} From 882736f8867b3b6703f9331e46b8f8da51ea5cee Mon Sep 17 00:00:00 2001 From: ZihaoMu <zmu@amd.com> Date: Fri, 12 Jun 2026 14:32:44 +0800 Subject: [PATCH 277/289] ggml: support concat for scalar types at cuda backend (llama/24011) * cuda: support concat for scalar types * Update concat.cu * fix metal ci issue --- ggml/src/ggml-cuda/concat.cu | 142 ++++++++++++++---------- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +- ggml/src/ggml-metal/ggml-metal-device.m | 11 +- 3 files changed, 101 insertions(+), 62 deletions(-) diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index adba4d522a4..8d557092b2b 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,16 +1,18 @@ #include "concat.cuh" +#include <stdint.h> + // contiguous kernels -template <int dim> -static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2) { +template <typename T, int dim> +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_cont(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); const int64_t n = ne0 * ne1 * ne2; @@ -50,37 +52,37 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont } } -static void concat_f32_cuda(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int dim, - cudaStream_t stream) { +template <typename T> +static void concat_cont_cuda(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { const int64_t n = ne0 * ne1 * ne2; const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); - ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + ggml_cuda_kernel_launch(concat_cont<T, 0>, launch_params, x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_cont<1> - <<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<T, 1><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_cont<2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<T, 2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) -template <int dim> +template <typename T, int dim> static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) - concat_f32_non_cont( + concat_non_cont( const char * src0, const char * src1, char * dst, @@ -107,61 +109,49 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) uint64_t nb0, uint64_t nb1, uint64_t nb2, - uint64_t nb3){ + uint64_t nb3) { static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - const float * x; + const T * x; for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + x = (const T *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); } else { if constexpr (dim == 0) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + i1*nb11 + (i0 - ne00)*nb10); } else if constexpr (dim == 1) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + (i1 - ne01)*nb11 + i0*nb10); } else if constexpr (dim == 2) { - x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + (i2 - ne02)*nb12 + i1*nb11 + i0*nb10); } else if constexpr (dim == 3) { - x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + (i3 - ne03)*nb13 + i2*nb12 + i1*nb11 + i0*nb10); } } - float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + T * y = (T *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } } - -void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - cudaStream_t stream = ctx.stream(); - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - +template <typename T> +static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, int dim, cudaStream_t stream) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - - float * dst_d = (float *)dst->data; + const T * src0_d = (const T *) src0->data; + const T * src1_d = (const T *) src1->data; + T * dst_d = (T *) dst->data; if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + concat_cont_cuda( + src0_d + i3*(src0->nb[3] / sizeof(T)), + src1_d + i3*(src1->nb[3] / sizeof(T)), + dst_d + i3*( dst->nb[3] / sizeof(T)), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } @@ -169,13 +159,13 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const size_t size0 = ggml_nbytes(src0); const size_t size1 = ggml_nbytes(src1); - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream)); } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); auto launch_kernel = [&](auto dim) { - concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( + concat_non_cont<T, dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], @@ -203,3 +193,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } } + +void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(!ggml_is_quantized(src0->type)); + GGML_ASSERT(ggml_blck_size(src0->type) == 1); + + switch (ggml_type_size(src0->type)) { + case 1: + concat_cuda<uint8_t>(src0, src1, dst, dim, stream); + break; + case 2: + concat_cuda<uint16_t>(src0, src1, dst, dim, stream); + break; + case 4: + concat_cuda<uint32_t>(src0, src1, dst, dim, stream); + break; + case 8: + concat_cuda<uint64_t>(src0, src1, dst, dim, stream); + break; + default: + GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type)); + break; + } +} diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e779a9be9e9..61041bdc16b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5345,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + !ggml_is_quantized(src0_type) && + ggml_blck_size(src0_type) == 1 && + (ggml_type_size(src0_type) == 1 || + ggml_type_size(src0_type) == 2 || + ggml_type_size(src0_type) == 4 || + ggml_type_size(src0_type) == 8); } break; case GGML_OP_CONV_TRANSPOSE_1D: { diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 05d7f43051b..d583bd6efc0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1120,8 +1120,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: return true; + case GGML_OP_CONCAT: + { + // kernel_concat copies one float-sized value per element. + // Other scalar types need a type-generic copy kernel first. + const enum ggml_type src0_type = op->src[0]->type; + const enum ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32); + } case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: From f35f47b5d242484ef405c82ac1c40ae61e8e582c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Fri, 12 Jun 2026 15:32:00 +0300 Subject: [PATCH 278/289] ggml : bump version to 0.15.1 (ggml/1541) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index cd0e4fef978..249ed3da290 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 15) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 0a3fa9ca17960dc2419566f2c03ff8913edfbe17 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 15 Jun 2026 09:13:43 +0300 Subject: [PATCH 279/289] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6e1bf3a1f4b..87d353ef452 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -7142aa6bf9fcaeec0fef8d80fcd90afe4268adf1 +3af5f5760e19a96427f5f7a93b79cbdf3d4b265b From 0ec0845110dc934911dc48e8c5beb5ad3189b3f3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 15 Jun 2026 09:15:48 +0300 Subject: [PATCH 280/289] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 90 ++--- examples/talk-llama/llama-arch.h | 11 + examples/talk-llama/llama-context.cpp | 113 +++++- examples/talk-llama/llama-context.h | 11 + examples/talk-llama/llama-cparams.h | 3 + examples/talk-llama/llama-ext.h | 16 + examples/talk-llama/llama-graph.cpp | 38 ++- examples/talk-llama/llama-graph.h | 14 +- examples/talk-llama/llama-hparams.h | 1 + examples/talk-llama/llama-model-loader.cpp | 1 + examples/talk-llama/llama-model.cpp | 19 +- examples/talk-llama/llama-model.h | 7 + examples/talk-llama/llama-vocab.cpp | 35 +- examples/talk-llama/llama-vocab.h | 8 +- examples/talk-llama/models/delta-net-base.cpp | 41 ++- examples/talk-llama/models/eagle3.cpp | 323 ++++++++++++++++++ .../talk-llama/models/gemma4-assistant.cpp | 3 + examples/talk-llama/models/gemma4.cpp | 2 + examples/talk-llama/models/llama.cpp | 2 + examples/talk-llama/models/models.h | 17 +- examples/talk-llama/models/openai-moe.cpp | 2 + examples/talk-llama/models/plamo2.cpp | 6 +- examples/talk-llama/models/qwen3.cpp | 2 + examples/talk-llama/models/qwen35.cpp | 2 +- examples/talk-llama/models/qwen3moe.cpp | 2 + 25 files changed, 672 insertions(+), 97 deletions(-) create mode 100644 examples/talk-llama/models/eagle3.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 6a5d5f8d2ac..9f93d5bc7ce 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -3,7 +3,6 @@ #include "llama-impl.h" #include <map> -#include <set> #include <vector> static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { @@ -128,6 +127,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_EAGLE3, "eagle3" }, { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, @@ -292,46 +292,51 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_TARGET_LAYERS, "%s.target_layers" }, + { LLM_KV_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_NORM_BEFORE_RESIDUAL, "%s.norm_before_residual" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature dims { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, - { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, - { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, - { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, - - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, - { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, - { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, - { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, - { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, - { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, - { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, - { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, - { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, - { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, + { LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, "tokenizer.ggml.normalizer.strip_accents" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, @@ -559,6 +564,10 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + { LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" }, + { LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" }, + { LLM_TENSOR_FC, "fc" }, + { LLM_TENSOR_D2T, "d2t" }, }; // declare information about the model weight tensors: @@ -783,6 +792,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { // latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + {LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + // eagle3 + {LLM_TENSOR_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 03b1a265d67..c5245fb5891 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -141,6 +141,7 @@ enum llm_arch { LLM_ARCH_KIMI_LINEAR, LLM_ARCH_TALKIE, LLM_ARCH_MELLUM, + LLM_ARCH_EAGLE3, LLM_ARCH_UNKNOWN, }; @@ -314,6 +315,7 @@ enum llm_kv { LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, + LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -336,6 +338,10 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_TARGET_LAYERS, + LLM_KV_TARGET_HIDDEN_SIZE, + LLM_KV_NORM_BEFORE_RESIDUAL, + LLM_KV_SHORTCONV_L_CACHE, LLM_KV_XIELU_ALPHA_N, @@ -566,8 +572,13 @@ enum llm_tensor { LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_MASKED_EMBD_CENTROIDS, + LLM_TENSOR_MASKED_EMBD_ORDERING, + LLM_TENSOR_FC, + LLM_TENSOR_D2T, }; + enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 9a40c4366af..168dbabd766 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -71,6 +71,9 @@ llama_context::llama_context( cparams.no_perf = params.no_perf; cparams.warmup = false; + cparams.embeddings_layer_inp.resize(hparams.n_layer(), false); + embd_layer_inp.resize(hparams.n_layer()); + cparams.ctx_type = params.ctx_type; cparams.pooling_type = params.pooling_type; @@ -91,12 +94,21 @@ llama_context::llama_context( if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { if (params.ctx_other == nullptr) { // TODO: change from runtime_error to llama_exception to avoid printing error message - throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)"); + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this warning is normal during memory fitting)"); } cparams.ctx_other = params.ctx_other; } + if (model.arch == LLM_ARCH_EAGLE3) { + if (model.tok_embd == nullptr || model.output == nullptr) { + if (params.ctx_other == nullptr) { + throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)"); + } + cparams.ctx_other = params.ctx_other; + } + } + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -194,7 +206,7 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); - cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max; + cparams.n_outputs_max = params.n_outputs_max == 0 || llama_model_has_encoder(&model) ? cparams.n_batch : params.n_outputs_max; cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; @@ -938,6 +950,14 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) { } } +float * llama_context::get_embeddings_layer_inp(uint32_t lid) { + output_reorder(); + + GGML_ASSERT(lid < embd_layer_inp.size() && embd_layer_inp[lid].has_data()); + + return embd_layer_inp[lid].data; +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1125,6 +1145,17 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) { cparams.embeddings_nextn_masked = masked; } +void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { + LLAMA_LOG_DEBUG("%s: lid = %d, enable = %d\n", __func__, lid, enable); + + GGML_ASSERT(lid < model.hparams.n_layer()); + + cparams.embeddings_layer_inp[lid] = enable; + + // note: without this reserve, the draft acceptance drops to zero. not sure why - this is unexpected + sched_need_reserve = true; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1350,7 +1381,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd_inp(); + // eagle3/DFlash: features as encoder input, and non-draft paths fall back to model's input dim + const int64_t n_embd = hparams.n_embd_inp(); const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -1925,6 +1957,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + extract_layer_inputs(res, n_tokens_prev, ubatch.n_tokens); + // extract nextn embeddings before // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. { @@ -2029,6 +2063,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; @@ -2041,9 +2076,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - size_t backend_float_count = 0; size_t backend_token_count = 0; + size_t embd_layer_inp_float_count = 0; logits.size = has_logits ? n_vocab*n_outputs_max : 0; embd.size = has_embd ? n_embd_out*n_outputs_max : 0; @@ -2055,6 +2090,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd_nextn.size = (size_t) n_embd_out * n_batch; } + for (bool enabled : cparams.embeddings_layer_inp) { + if (enabled) { + embd_layer_inp_float_count += (size_t) n_embd * n_batch; + } + } + // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { @@ -2069,8 +2110,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_nextn.size + embd_layer_inp_float_count + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -2087,6 +2128,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { logits.data = nullptr; embd.data = nullptr; embd_nextn.data = nullptr; + for (auto & layer_inp : embd_layer_inp) { + layer_inp = {nullptr, 0}; + } } auto * buft = ggml_backend_cpu_buffer_type(); @@ -2118,6 +2162,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0}; offset += embd_nextn.size * sizeof(float); + for (uint32_t il = 0; il < embd_layer_inp.size(); ++il) { + if (cparams.embeddings_layer_inp[il]) { + embd_layer_inp[il] = buffer_view<float>{(float *) (base + offset), (size_t) n_embd * n_batch}; + offset += embd_layer_inp[il].size * sizeof(float); + } else { + embd_layer_inp[il] = buffer_view<float>{nullptr, 0}; + } + } + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2164,6 +2217,34 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } +void llama_context::extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens) { + for (uint32_t il = 0; il < cparams.embeddings_layer_inp.size(); ++il) { + if (!cparams.embeddings_layer_inp[il]) { + continue; + } + if (!embd_layer_inp[il].has_data()) { + GGML_ABORT("output layer input buffer not allocated"); + } + ggml_tensor * t = res->get_layer_inp((int) il); + if (!t) { + GGML_ABORT("layer input tensor not found"); + } + + const size_t nbytes = ggml_nbytes(t); + const size_t nfloats = nbytes / sizeof(float); + GGML_ASSERT(n_tokens > 0); + GGML_ASSERT(nfloats % n_tokens == 0); + + const size_t row_floats = nfloats / n_tokens; + const size_t dst_offset = token_offset * row_floats; + GGML_ASSERT(dst_offset + nfloats <= embd_layer_inp[il].size); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t); + GGML_ASSERT(backend != nullptr); + ggml_backend_tensor_get_async(backend, t, embd_layer_inp[il].data + dst_offset, 0, nbytes); + } +} + void llama_context::output_reorder() { const uint64_t n_vocab = model.vocab.n_tokens(); const uint64_t n_embd = model.hparams.n_embd; @@ -2190,6 +2271,16 @@ void llama_context::output_reorder() { } } + if (embd_layer_inp.size() > 0) { + for (int lid = 0; lid < (int) embd_layer_inp.size(); ++lid) { + if (embd_layer_inp[lid].size > 0) { + for (uint64_t k = 0; k < n_embd; ++k) { + std::swap(embd_layer_inp[lid].data[i0*n_embd + k], embd_layer_inp[lid].data[i1*n_embd + k]); + } + } + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -3604,6 +3695,10 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { ctx->set_embeddings_nextn(value, masked); } +void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool value) { + ctx->set_embeddings_layer_inp(lid, value); +} + llama_memory_t llama_get_memory(const struct llama_context * ctx) { if (!ctx) { return nullptr; @@ -3624,6 +3719,12 @@ float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) { return ctx->get_embeddings_nextn_ith(i); } +float * llama_get_embeddings_layer_inp(llama_context * ctx, uint32_t lid) { + ctx->synchronize(); + + return ctx->get_embeddings_layer_inp(lid); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 6f8f59a22a3..853052be2ca 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -88,6 +88,8 @@ struct llama_context { float * get_embeddings_nextn(); float * get_embeddings_nextn_ith(int32_t i); + float * get_embeddings_layer_inp(uint32_t lid); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -112,6 +114,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_nextn(bool value, bool masked); + void set_embeddings_layer_inp(uint32_t lid, bool enable); void set_causal_attn(bool value); void set_warmup(bool value); @@ -226,6 +229,10 @@ struct llama_context { // map the output row index `i` to batch index int64_t output_resolve_row(int32_t i) const; + // async-copy enabled layer-input tensors (per cparams.output_layer_inp) + // from backend into host-side embd_layer_inp buffers + void extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens); + // // graph // @@ -288,6 +295,10 @@ struct llama_context { // sets llm_graph_result::t_h_nextn buffer_view<float> embd_nextn = {nullptr, 0}; + // host buffers for output layer input embeddings, per layer + // populated when cparams.output_layer_inp[il] is true + std::vector<buffer_view<float>> embd_layer_inp; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map<llama_seq_id, llama_sampler *> samplers; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 8a35d389ef4..2b109f909c0 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include <cstdint> +#include <vector> #define LLAMA_MAX_SEQ 256 @@ -44,6 +45,8 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + std::vector<bool> embeddings_layer_inp; // [n_layer()] extract input embeddings for layer + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index bd74544129b..b744af52864 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -101,4 +101,20 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); +// Set whether the context outputs the input embeddings of a specific layer +LLAMA_API void llama_set_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid, bool value); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid); + LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); + +// +// model/context data extraction +// + +// returns pointer to the target-model layer indices +LLAMA_API const int32_t * llama_model_target_layer_ids (const struct llama_model * model); +// returns the number of extracted layers from target model +LLAMA_API uint32_t llama_model_target_layer_ids_n(const struct llama_model * model); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index da7a9295561..7468bd9b79e 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -567,7 +567,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); } - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live + if (self_kq_mask && self_kq_mask->buffer) { + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { @@ -575,7 +578,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); } - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -607,7 +612,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there } - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + if (self_kq_mask && self_kq_mask->buffer) { + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { @@ -615,7 +622,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there } - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -895,6 +904,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + + t_layer_inp.resize(LLAMA_MAX_LAYERS); + std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr); + t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -923,7 +936,7 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } -void llm_graph_result::set_outputs() { +void llm_graph_result::set_outputs(const llm_graph_params & params) { if (t_logits != nullptr) { ggml_set_output(t_logits); } @@ -936,6 +949,15 @@ void llm_graph_result::set_outputs() { if (t_h_nextn != nullptr) { ggml_set_output(t_h_nextn); } + { + const auto & embeddings_layer_inp = params.cparams.embeddings_layer_inp; + for (size_t il = 0; il < embeddings_layer_inp.size(); ++il) { + if (embeddings_layer_inp[il]) { + GGML_ASSERT(t_layer_inp[il] != nullptr && "layer input tensor is null"); + ggml_set_output(t_layer_inp[il]); + } + } + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -1864,9 +1886,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { res->t_inp_embd = cur; // For Granite architecture - // NOTE: Only apply scale to token inputs. Raw embeddings are assumed to be - // multimodal inputs that should not be scaled. - if (ubatch.token && hparams.f_embedding_scale != 0.0f) { + // NOTE: For deepstack models, only apply scale to token inputs (ie text-only input). + // Raw embeddings are assumed to be multimodal inputs that should not be scaled. + if (hparams.f_embedding_scale != 0.0f && (ubatch.token || hparams.n_deepstack_layers == 0)) { if (!ggml_is_contiguous(cur)) { cur = ggml_cont(ctx0, cur); } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 6793846e3ea..cc5cfe51dcd 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -705,6 +705,8 @@ class llm_graph_result { ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } ggml_tensor * get_h_nextn() const { return t_h_nextn; } + ggml_tensor * get_layer_inp(int il) const { return t_layer_inp[il]; } + ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -713,7 +715,7 @@ class llm_graph_result { void reset(); void set_inputs(const llama_ubatch * ubatch); - void set_outputs(); + void set_outputs(const llm_graph_params & params); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -734,10 +736,12 @@ class llm_graph_result { ggml_tensor * t_embd_pooled = nullptr; ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm - std::map<llama_seq_id, ggml_tensor*> t_sampled_logits; - std::map<llama_seq_id, ggml_tensor*> t_candidates; - std::map<llama_seq_id, ggml_tensor*> t_sampled; - std::map<llama_seq_id, ggml_tensor*> t_sampled_probs; + std::vector<ggml_tensor *> t_layer_inp; + + std::map<llama_seq_id, ggml_tensor *> t_sampled_logits; + std::map<llama_seq_id, ggml_tensor *> t_candidates; + std::map<llama_seq_id, ggml_tensor *> t_sampled; + std::map<llama_seq_id, ggml_tensor *> t_sampled_probs; std::vector<llm_graph_input_ptr> inputs; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 032944cb481..d045059a63e 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -45,6 +45,7 @@ struct llama_hparams { bool rope_finetuned; bool use_par_res; bool swin_norm; + bool norm_before_residual = false; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 0d1cf3cc33b..474cabdfc09 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -394,6 +394,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required); template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required); + template bool llama_model_loader::get_arr<std::vector<int32_t>>(enum llm_kv kid, std::vector<int32_t> & result, bool required); template<typename T> bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 4f12e0949ac..7281ed79f10 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -287,6 +287,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_qwen35moe(params); case LLM_ARCH_MISTRAL3: return new llama_model_mistral3(params); + case LLM_ARCH_EAGLE3: + return new llama_model_eagle3(params); case LLM_ARCH_MIMO2: return new llama_model_mimo2(params); case LLM_ARCH_KIMI_LINEAR: @@ -2238,7 +2240,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // TODO: move reranking logic here and generalize llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); - llm->res->set_outputs(); + llm->res->set_outputs(params); return llm->res->get_gf(); } @@ -2406,6 +2408,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_EAGLE3: case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: @@ -2600,8 +2603,9 @@ uint64_t llama_model_n_params(const llama_model * model) { bool llama_model_has_encoder(const llama_model * model) { switch (model->arch) { - case LLM_ARCH_T5: return true; - case LLM_ARCH_T5ENCODER: return true; + case LLM_ARCH_T5: + case LLM_ARCH_T5ENCODER: + case LLM_ARCH_EAGLE3: return true; default: return false; } } @@ -2687,3 +2691,12 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); } } + +const int32_t * llama_model_target_layer_ids(const struct llama_model * model) { + const auto & v = model->target_layer_ids; + return v.empty() ? nullptr : v.data(); +} + +uint32_t llama_model_target_layer_ids_n(const struct llama_model * model) { + return (uint32_t) model->target_layer_ids.size(); +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 992c8d9c8fd..f4718f6d584 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -569,6 +569,13 @@ struct llama_model { struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; + // eagle3 + struct ggml_tensor * fc = nullptr; // feature fusion layer + struct ggml_tensor * d2t = nullptr; // draft to target vocabulary mapping + + // unified vector to store target-model extracted layer ids in eagle3, dflash, etc. + std::vector<int32_t> target_layer_ids; + std::vector<llama_layer> layers; //Dense linear projections for SentenceTransformers models like embeddinggemma diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 9a4bed49487..8543e178dba 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -764,7 +764,7 @@ struct llm_tokenizer_wpm_session { void tokenize(const std::string & text, std::vector<llama_token> & output) { // normalize and split by whitespace - std::vector<std::string> words = preprocess(text, vocab.get_normalizer_lowercase()); + std::vector<std::string> words = preprocess(text, vocab.get_normalizer_opts()); // bos token prepended already // find the longest tokens that form the words @@ -809,11 +809,14 @@ struct llm_tokenizer_wpm_session { } // TODO: reduce string copies by using cpts_offs array - static std::vector<std::string> preprocess(const std::string & text, bool lowercase) { - const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + static std::vector<std::string> preprocess(const std::string & text, const llama_vocab::normalizer_options & normalizer_opts) { + std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text); + if (normalizer_opts.strip_accents) { + cpts = unicode_cpts_normalize_nfd(cpts); + } std::vector<std::string> words(1, ""); - for (const uint32_t cpt : cpts_nfd) { + for (const uint32_t cpt : cpts) { const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { @@ -828,7 +831,11 @@ struct llm_tokenizer_wpm_session { continue; } - const std::string s = unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); + if (normalizer_opts.strip_accents && flags.is_accent_mark) { + continue; + } + + const std::string s = unicode_cpt_to_utf8(normalizer_opts.lowercase ? unicode_tolower(cpt) : cpt); if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); @@ -1692,7 +1699,7 @@ struct llm_tokenizer_whitespace_session : llm_tokenizer_bpe_session { llm_tokenizer_whitespace_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} void tokenize(const std::string & text, std::vector<llama_token> & output) override { - const bool lowercase = vocab.get_normalizer_lowercase(); + const bool lowercase = vocab.get_normalizer_opts().lowercase; std::string segment; auto flush = [&]() { @@ -1797,7 +1804,9 @@ struct llama_vocab::impl { bool remove_extra_whitespaces = false; bool escape_whitespaces = true; bool treat_whitespace_as_suffix = false; - bool normalizer_lowercase = true; // Lowercase normalizer (tokenizer.json) + + // BertNormalizer options + llama_vocab::normalizer_options normalizer_opts; std::unordered_map<std::string, llama_token> token_to_id; std::vector<token_data> id_to_token; @@ -2172,7 +2181,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "whitespace") { pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE; - normalizer_lowercase = false; + normalizer_opts.lowercase = false; } else if ( tokenizer_pre == "refact") { pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; @@ -2532,8 +2541,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } } - // Lowercase normalizer flag (consulted by WPM / whitespace BPE) - ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false); + // BertNormalizer options + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_opts.lowercase, false); + normalizer_opts.strip_accents = normalizer_opts.lowercase; + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, normalizer_opts.strip_accents, false); // suppress tokens { @@ -3969,8 +3980,8 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const { return pimpl->treat_whitespace_as_suffix; } -bool llama_vocab::get_normalizer_lowercase() const { - return pimpl->normalizer_lowercase; +const llama_vocab::normalizer_options & llama_vocab::get_normalizer_opts() const { + return pimpl->normalizer_opts; } const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const { diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 2626ae36e33..707cd4bac4b 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -76,6 +76,12 @@ struct llama_vocab { llama_token_attr attr; }; + struct normalizer_options { + bool lowercase = true; + bool strip_accents = true; + // TODO: clean_text, handle_chinese_chars + }; + llama_vocab(); ~llama_vocab(); @@ -141,7 +147,7 @@ struct llama_vocab { bool get_remove_extra_whitespaces () const; bool get_escape_whitespaces () const; bool get_treat_whitespace_as_suffix() const; - bool get_normalizer_lowercase () const; + const normalizer_options & get_normalizer_opts() const; const std::vector<llama_token> & get_suppress_tokens() const; diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 4f4c7cac7a8..ad9ce771408 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -398,9 +398,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. - ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); + // K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs]. + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -564,11 +563,8 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( const int64_t D = S_v * S_v * H_v; const int64_t K = cparams.n_rs_seq + 1; - // TODO: remove pad + simplify - ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); - ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); - - ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + // state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output. + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K); if (n_seq_tokens > 1) { cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); } else { @@ -587,21 +583,24 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( cb(output, "attn_output", il); const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); - for (int64_t k_i = 0; k_i < K; ++k_i) { - const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); - ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, - S_v, S_v, H_v, n_seqs, - ggml_row_size(gdn_out->type, S_v), - ggml_row_size(gdn_out->type, S_v * S_v), - ggml_row_size(gdn_out->type, S_v * S_v * H_v), - ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); - ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, - hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - ((size_t) cache_slot * mem_size + kv_head) * row_size); + // op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten + const int64_t n_written = std::min<int64_t>(n_seq_tokens, K); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); - } + // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i) + ggml_tensor * src = ggml_view_3d(ctx0, gdn_out, + D, n_seqs, n_written, + ggml_row_size(gdn_out->type, D), + ggml_row_size(gdn_out->type, state_size_per_snap), + ggml_row_size(gdn_out->type, attn_score_elems)); + + ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all, + D, n_seqs, n_written, + ssm_states_all->nb[1], + (size_t) mem_size * row_size, + (size_t) kv_head * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); return output; } diff --git a/examples/talk-llama/models/eagle3.cpp b/examples/talk-llama/models/eagle3.cpp new file mode 100644 index 00000000000..3321b390515 --- /dev/null +++ b/examples/talk-llama/models/eagle3.cpp @@ -0,0 +1,323 @@ +#include "models.h" + +void llama_model_eagle3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (!ml.get_arr(LLM_KV_TARGET_LAYERS, target_layer_ids, false)) { + throw std::runtime_error("EAGLE3 model requires 'extract_layers' in GGUF metadata"); + } + if (target_layer_ids.size() != 3) { + throw std::runtime_error("EAGLE3 requires exactly 3 entries in 'extract_layers'"); + } + LLAMA_LOG_INFO("%s: EAGLE3 extract_layers = [%d, %d, %d]\n", __func__, + target_layer_ids[0], + target_layer_ids[1], + target_layer_ids[2]); + + uint32_t n_embd_tgt = 0; + + ml.get_key(LLM_KV_TARGET_HIDDEN_SIZE, n_embd_tgt); + LLAMA_LOG_INFO("%s: EAGLE3 n_embd_tgt = %u (draft n_embd = %u)\n", __func__, n_embd_tgt, hparams.n_embd); + + hparams.n_embd_inp_impl = (uint32_t) target_layer_ids.size() * n_embd_tgt; + + // eagle3 norm_before_residual (optional, default false) + // compatible with Readhat eagle3 speculator model + ml.get_key(LLM_KV_NORM_BEFORE_RESIDUAL, hparams.norm_before_residual, false); + if (hparams.norm_before_residual) { + LLAMA_LOG_INFO("%s: EAGLE3gnorm_before_residual = true\n", __func__); + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_eagle3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd_attn_input = 2 * n_embd; + + // Get vocab size from the d2t tensor in the GGUF file (optional - only needed if eagle3 has different vocab_size than target) + // d2t: draft to target vocabulary mapping + int64_t n_draft_vocab = n_vocab; // Default: same as target vocab + const struct ggml_tensor * d2t_meta = ml->get_tensor_meta("d2t"); + if (d2t_meta) { + n_draft_vocab = d2t_meta->ne[0]; // update draft vocab size + d2t = create_tensor(tn(LLM_TENSOR_D2T), {n_draft_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using d2t mapping (draft_vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } else { + d2t = nullptr; // no d2t, use default vocab size + LLAMA_LOG_INFO("%s: EAGLE3 without d2t - sharing same vocab_size with target (vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } + + // Feature fusion layer: projects 3 target layers to draft hidden size + fc = create_tensor(tn(LLM_TENSOR_FC, "weight"), {n_embd_inp, n_embd}, 0); + + // Output layer (uses draft vocab size) + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, TENSOR_NOT_REQUIRED); + + // Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own) + const struct ggml_tensor * tok_embd_meta = ml->get_tensor_meta(tn(LLM_TENSOR_TOKEN_EMBD, "weight").str().c_str()); + if (tok_embd_meta) { + const int64_t n_target_vocab = tok_embd_meta->ne[1]; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_target_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using its own token_embd (vocab = %lld)\n", __func__, (long long)n_target_vocab); + } + + // Single decoder layer + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // input_layernorm: applied to token embeddings + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // eagle3 specific: hidden_norm applied to fused target features + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + + // Attention takes input_embeds_normed + fused_target_normed as input + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_attn_input, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_attn_input, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_attn_input, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // rope_freqs for llama3 rope scaling (optional - only if eagle3 config has rope_scaling) + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_eagle3::build_arch_graph(const llm_graph_params & params) const { + switch (params.gtype) { + case LLM_GRAPH_TYPE_ENCODER: + return std::make_unique<graph<true>>(*this, params); + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + return std::make_unique<graph<false>>(*this, params); + default: + GGML_ABORT("invalid graph type"); + }; +} + +template <> +ggml_tensor * llama_model_eagle3::graph<true>::build_inp_embd_enc() const { + ggml_tensor * cur = nullptr; + + // Input: Target model features (3 layers concatenated: low, mid, high) + // Data will be provided via ubatch->embd in encode_eagle3_features() + auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp()); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,hparams.n_embd_inp(), n_tokens); + ggml_set_input(inp_target->embd); + + cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + +// eagle3 Encoder: processes target model features through feature fusion layer +// Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high +// Output: g_embeddings e.g. [4096, n_tokens] stored in context +template <> +llama_model_eagle3::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = nullptr; + + cur = build_inp_embd_enc(); + + // Feature fusion layer + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); + + // Output: g_embeddings e.g. [4096, n_tokens] + // store in t_h_nextn (same as MTP) so can be read via llama_get_embeddings_nextn(ctx_dft) + ggml_set_output(cur); + res->t_h_nextn = cur; + + ggml_build_forward_expand(gf, cur); +} + +// eagle3 Decoder: processes draft tokens using g_embeddings from encoder +// Input: draft tokens + g_embeddings from encoder +// Output: draft logits +template <> +llama_model_eagle3::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_layer == 1); // eagle3 has only one decoder layer + + ggml_tensor * cur; + ggml_tensor * inpL; + + // eagle3 Decoder receives: + // 1. Token embeddings (e.g.from eagle3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) + // 2. g_embeddings from encoder + auto * tok_embd = model.tok_embd; + if (model.tok_embd == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->tok_embd != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); + tok_embd = model_other->tok_embd; + } + + auto inp = std::make_unique<llm_graph_input_embd>(n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp->embd); + + ggml_tensor * inp_embd = ggml_get_rows(ctx0, tok_embd, inp->tokens); + cb(inp_embd, "inp_embd", -1); + + ggml_tensor * inp_g = inp->embd; + cb(inp_g, "inp_g_embeddings", -1); + + res->add_input(std::move(inp)); + + inpL = inp_g; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + // Single decoder layer (il = 0) + const int il = 0; + { + // Apply input_layernorm to the token embeddings + ggml_tensor * embd_norm = build_norm(inp_embd, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(embd_norm, "embd_norm", il); + + // Apply hidden_norm to inp_g + ggml_tensor * g_norm = build_norm(inp_g, + model.layers[il].attn_norm_2, NULL, + LLM_NORM_RMS, -1); + cb(g_norm, "g_norm", il); + + // norm_before_residual: determines what goes into the residual connection (compatible with Readhat eagle3 speculator model) + // - false (default): use raw inp_g for residual + // - true: use normalized g_norm for residual + // inpL is the concatenated input (normalized inp_embd + normalized inp_g) + ggml_tensor * inpSA = hparams.norm_before_residual ? g_norm : inpL; + + // Concatenate normalized inp_embd and normalized inp_g + cur = ggml_concat(ctx0, embd_norm, g_norm, il); + cb(cur, "concat_embd", il); + + // Self-attention with concatenated input + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // rope freq factors, returns nullptr if not available + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + + // Add residual and update it + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Apply FFN norm to the sum + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + // Output norm with residual + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "eagle3_prenorm", il); + + inpL = cur; + } + + cur = inpL; + + // Output prenorm state (for next token's g_embeddings in autoregressive generation) + ggml_set_output(cur); + res->t_h_nextn = cur; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head - projects to draft vocabulary + // if the draft has no own output projection, inherit the target model's lm_head + auto * output = model.output; + if (output == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->output != nullptr && "EAGLE3 decoder requires an output projection (own or from target model)"); + output = model_other->output; + } + cur = build_lora_mm(output, cur); + + if (model.d2t) { + const int64_t n_draft_vocab = cur->ne[0]; + const int64_t n_outputs = cur->ne[1]; + const int64_t n_vocab = (int64_t) model.vocab.n_tokens(); + + GGML_ASSERT(model.d2t->type == GGML_TYPE_I64); + GGML_ASSERT(model.d2t->ne[0] == n_draft_vocab); + + ggml_tensor * logits = ggml_fill(ctx0, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, n_vocab, n_outputs), -INFINITY); + cur = ggml_set_rows(ctx0, logits, + ggml_reshape_3d(ctx0, cur, 1, n_draft_vocab, n_outputs), + ggml_reshape_3d(ctx0, model.d2t, n_draft_vocab, 1, 1)); + cur = ggml_reshape_2d(ctx0, cur, n_vocab, n_outputs); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/gemma4-assistant.cpp b/examples/talk-llama/models/gemma4-assistant.cpp index 5b7a25a5aba..6378130e79e 100644 --- a/examples/talk-llama/models/gemma4-assistant.cpp +++ b/examples/talk-llama/models/gemma4-assistant.cpp @@ -39,6 +39,9 @@ void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED); + const int64_t n_embd_backbone = hparams.n_embd_inp(); nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index 6f7fcd645cb..6a96979cebd 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -210,6 +210,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para const float freq_scale_l = model.get_rope_freq_scale(cparams, il); const int n_rot_l = hparams.n_rot(il); + res->t_layer_inp[il] = inpL; + // norm cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index c0ec7e0a9ad..4bfebc8843c 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -124,6 +124,8 @@ llama_model_llama::graph<embed>::graph(const llama_model & model, const llm_grap ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index c137e32e8fd..ee3aff07b9a 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) + // use the ggml_gated_delta_net fused operator (K=1; state has shape [S_v, S_v, H_v, n_seqs]) std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -1089,6 +1089,21 @@ struct llama_model_glm_dsa : public llama_model_base { std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_eagle3 : public llama_model_base { + llama_model_eagle3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool is_enc> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + + ggml_tensor * build_inp_embd_enc() const; + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + struct llama_model_mistral4 : public llama_model_deepseek2 { llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {} diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 3ab15d61f08..6d74f9c7e6e 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -75,6 +75,8 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index b93cf48bc5c..0b81513c368 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -11,6 +11,10 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } @@ -273,7 +277,7 @@ ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_inpu GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - GGML_ASSERT(d_inner % n_head == 0); + GGML_ASSERT(d_inner % n_heads == 0); GGML_ASSERT(n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 1d0d2fab362..f4b2a2aebe0 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -69,6 +69,8 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 4b642cff467..6783d98ec20 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -173,7 +173,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 317e668bec7..6f6df5390e3 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -78,6 +78,8 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm From db5a84bd79926f783b199f5707af42dd99b60f2e Mon Sep 17 00:00:00 2001 From: Rum Nguyen <160252724+rumitvn@users.noreply.github.com> Date: Tue, 16 Jun 2026 13:58:09 +0700 Subject: [PATCH 281/289] cli : add --version flag (#3878) Adds a `--version` option to whisper-cli that prints the library version via `whisper_version()` and exits, plus a corresponding entry in the help output. Mirrors the existing `-h`/`--help` handling. Closes #608 --- examples/cli/cli.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 7ca563dc250..e505bf0e18d 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -151,6 +151,10 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params whisper_print_usage(argc, argv, params); exit(0); } + if (arg == "--version") { + fprintf(stdout, "whisper.cpp version: %s\n", whisper_version()); + exit(0); + } #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } @@ -234,6 +238,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " --version show version information and exit\n"); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); From 48f628a84833905ee4a0658ee6d4a5c915ce1997 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 16 Jun 2026 12:28:23 +0200 Subject: [PATCH 282/289] release : v1.8.7 (#3881) --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3932cf2845e..b2e936e7267 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.6) +project("whisper.cpp" VERSION 1.8.7) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index fe7fa74153a..19fdc70daab 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.6](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.6) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.8.7](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.7) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 1f2f34672ae..7c66c730c6c 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.6", + "version": "1.8.7", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 3805e602d3a3f80ca13211cb96900eae5aad4d1d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 16 Jun 2026 14:33:42 +0200 Subject: [PATCH 283/289] ci : only trigger release jobs for tags (#3883) * ci : only trigger release jobs for tags This commit removes the building of the release jobs on pushed to master. The motivation for this is that it can be confusing at the momement when releasing that the push to master also triggers the release jobs but the actual release will be skipped. With this change the release job is only run when a tag is pushed which should result in a single Release github actions job and make it easier to follow. * ci : add GGML_NATIVE=OFF for ubuntu-22-gcc --- .github/workflows/build-gcc.yml | 1 + .github/workflows/release.yml | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml index 3d8b5137344..53c1b2d783c 100644 --- a/.github/workflows/build-gcc.yml +++ b/.github/workflows/build-gcc.yml @@ -75,6 +75,7 @@ jobs: apt update apt install -y build-essential cmake libsdl2-dev git ccache cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache make diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 11d47546caa..ef2c3083c9f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,8 +13,6 @@ on: type: string push: - branches: - - master tags: - 'v*' From 9efddafb9153e1fb22bdc3dd3057072c99165ed2 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 16 Jun 2026 20:44:10 +0200 Subject: [PATCH 284/289] parakeet : add support for NVIDIA Parakeet (#3735) * parakeet : add support for NVIDIA Parakeet Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --- CMakeLists.txt | 37 + bindings/ruby/ext/extconf.rb | 2 +- cmake/parakeet-config.cmake.in | 30 + cmake/parakeet.pc.in | 10 + examples/CMakeLists.txt | 2 + examples/parakeet-cli/CMakeLists.txt | 8 + examples/parakeet-cli/README.md | 106 + examples/parakeet-cli/parakeet-cli.cpp | 243 ++ examples/parakeet-quantize/CMakeLists.txt | 7 + .../parakeet-quantize/parakeet-quantize.cpp | 230 + include/parakeet.h | 342 ++ models/convert-parakeet-to-ggml.py | 337 ++ models/for-tests-ggml-parakeet-tdt.bin | Bin 0 -> 16603 bytes models/generate-parakeet-test-model.py | 182 + models/requirements-parakeet.txt | 3 + scripts/quantize-parakeet.sh | 15 + scripts/upload-parakeet.py | 157 + src/CMakeLists.txt | 23 + src/parakeet-arch.h | 188 + src/parakeet.cpp | 3838 +++++++++++++++++ tests/CMakeLists.txt | 59 + tests/librispeech-parakeet/.gitignore | 6 + tests/librispeech-parakeet/Makefile | 15 + tests/librispeech-parakeet/README.md | 57 + tests/librispeech-parakeet/eval.mk | 39 + tests/librispeech-parakeet/eval.py | 47 + .../librispeech-parakeet/normalizers/LICENSE | 25 + .../normalizers/__init__.py | 2 + .../librispeech-parakeet/normalizers/basic.py | 80 + .../normalizers/english.json | 1741 ++++++++ .../normalizers/english.py | 550 +++ tests/parakeet-expected-diffusion-output.txt | 1 + tests/parakeet-expected-gb1-output.txt | 1 + tests/parakeet-expected-jfk-output.txt | 1 + tests/parakeet-verification.h | 110 + tests/run-tests.sh | 46 +- tests/test-parakeet-full.cpp | 101 + tests/test-parakeet.cpp | 99 + 38 files changed, 8733 insertions(+), 7 deletions(-) create mode 100644 cmake/parakeet-config.cmake.in create mode 100644 cmake/parakeet.pc.in create mode 100644 examples/parakeet-cli/CMakeLists.txt create mode 100644 examples/parakeet-cli/README.md create mode 100644 examples/parakeet-cli/parakeet-cli.cpp create mode 100644 examples/parakeet-quantize/CMakeLists.txt create mode 100644 examples/parakeet-quantize/parakeet-quantize.cpp create mode 100644 include/parakeet.h create mode 100755 models/convert-parakeet-to-ggml.py create mode 100644 models/for-tests-ggml-parakeet-tdt.bin create mode 100755 models/generate-parakeet-test-model.py create mode 100644 models/requirements-parakeet.txt create mode 100755 scripts/quantize-parakeet.sh create mode 100644 scripts/upload-parakeet.py create mode 100644 src/parakeet-arch.h create mode 100644 src/parakeet.cpp create mode 100644 tests/librispeech-parakeet/.gitignore create mode 100644 tests/librispeech-parakeet/Makefile create mode 100644 tests/librispeech-parakeet/README.md create mode 100644 tests/librispeech-parakeet/eval.mk create mode 100644 tests/librispeech-parakeet/eval.py create mode 100644 tests/librispeech-parakeet/normalizers/LICENSE create mode 100644 tests/librispeech-parakeet/normalizers/__init__.py create mode 100644 tests/librispeech-parakeet/normalizers/basic.py create mode 100644 tests/librispeech-parakeet/normalizers/english.json create mode 100644 tests/librispeech-parakeet/normalizers/english.py create mode 100644 tests/parakeet-expected-diffusion-output.txt create mode 100644 tests/parakeet-expected-gb1-output.txt create mode 100644 tests/parakeet-expected-jfk-output.txt create mode 100644 tests/parakeet-verification.h create mode 100644 tests/test-parakeet-full.cpp create mode 100644 tests/test-parakeet.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b2e936e7267..dff25f25a34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,12 +180,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS) set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h) + install(TARGETS whisper LIBRARY PUBLIC_HEADER) target_compile_definitions(whisper PRIVATE WHISPER_VERSION="${PROJECT_VERSION}" ) +set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h) +install(TARGETS parakeet LIBRARY PUBLIC_HEADER) + +target_compile_definitions(parakeet PRIVATE + PARAKEET_VERSION="${PROJECT_VERSION}" +) + configure_package_config_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake @@ -211,6 +219,35 @@ configure_file(cmake/whisper.pc.in install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) +set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet + PATH_VARS + PARAKEET_INCLUDE_INSTALL_DIR + PARAKEET_LIB_INSTALL_DIR + PARAKEET_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + VERSION ${WHISPER_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet) + +configure_file(cmake/parakeet.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) + # # programs, examples and tests # diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 4b09b6ebe13..99894f1234d 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -30,6 +30,6 @@ #{libs}: cmake-targets cmake-targets: #{"\t"}"#{cmake}" -S sources -B build #{options} - #{"\t"}"#{cmake}" --build build --config Release --target common whisper + #{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet EOF end diff --git a/cmake/parakeet-config.cmake.in b/cmake/parakeet-config.cmake.in new file mode 100644 index 00000000000..aadb55c2d19 --- /dev/null +++ b/cmake/parakeet-config.cmake.in @@ -0,0 +1,30 @@ +set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@) +set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) +set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) +set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@) + +@PACKAGE_INIT@ + +set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@") +set_and_check(PARAKEET_LIB_DIR "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@") +set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@") + +find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake) + +find_library(parakeet_LIBRARY parakeet + REQUIRED + HINTS ${PARAKEET_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) + +add_library(parakeet UNKNOWN IMPORTED) +set_target_properties(parakeet + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${parakeet_LIBRARY}" + INTERFACE_COMPILE_FEATURES cxx_std_11 + POSITION_INDEPENDENT_CODE ON) + +check_required_components(parakeet) diff --git a/cmake/parakeet.pc.in b/cmake/parakeet.pc.in new file mode 100644 index 00000000000..5a25fbb2e42 --- /dev/null +++ b/cmake/parakeet.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/include + +Name: parakeet +Description: Port of NVIDIA's Parakeet model in C/C++ +Version: @PROJECT_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lparakeet +Cflags: -I${includedir} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 0bb54cec489..7aedb9df683 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -107,6 +107,8 @@ else() add_subdirectory(server) add_subdirectory(quantize) add_subdirectory(vad-speech-segments) + add_subdirectory(parakeet-cli) + add_subdirectory(parakeet-quantize) if (WHISPER_SDL2) add_subdirectory(stream) add_subdirectory(command) diff --git a/examples/parakeet-cli/CMakeLists.txt b/examples/parakeet-cli/CMakeLists.txt new file mode 100644 index 00000000000..adb9aba38ef --- /dev/null +++ b/examples/parakeet-cli/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET parakeet-cli) +add_executable(${TARGET} parakeet-cli.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md new file mode 100644 index 00000000000..ccb8404f542 --- /dev/null +++ b/examples/parakeet-cli/README.md @@ -0,0 +1,106 @@ +# whisper.cpp/examples/parakeet-cli + +This is an example of using the [Parakeet] model in whisper.cpp. + +### Download converted model +```console +$ hf download ggml-org/parakeet-GGUF parakeet-tdt-0.6b-v3-f16.bin --local-dir models +``` + +### Building +```console +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +### Usage +```console +$ ./build/bin/parakeet-cli --help + +usage: ./build/bin/parakeet-cli [options] file0 file1 ... +supported audio formats: flac, mp3, ogg, wav + +options: + -h, --help [default] show this help message and exit + -t N, --threads N [4 ] number of threads to use during computation + -m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path + -f, --file FILE [ ] input audio file + -ng, --no-gpu [false ] disable GPU + -dev N, --device N [0 ] GPU device to use + -ps, --print-segments [false ] print segment information +``` + +### Example +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. +``` + +To print segment information: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + +Segments (1): +Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country." +Tokens [38]: + [ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And" + [ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so" + [ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false "," + [ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my" + [ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f" + [ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell" + [ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow" + [ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer" + [ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic" + [ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans" + [10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false "," + [11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a" + [12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk" + [13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not" + [14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what" + [15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your" + [16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co" + [17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un" + [18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr" + [19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y" + [20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can" + [21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do" + [22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for" + [23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you" + [24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false "," + [25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a" + [26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk" + [27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what" + [28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you" + [29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can" + [30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do" + [31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for" + [32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your" + [33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co" + [34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un" + [35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr" + [36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y" + [37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "." +``` + +### Model conversion +Clone the original model from Hugging Face: +```console +$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 +``` +Convert the model: +```console +(venv) $ python models/convert-parakeet-to-ggml.py \ + --model <path to cloned model> \ + --out-dir models \ + --out-name ggml-parakeet-tdt-0.6b-v3-f16.bin +``` + +[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp new file mode 100644 index 00000000000..03ddc7f8b8c --- /dev/null +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -0,0 +1,243 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> +#include <thread> +#include <vector> +#include <cstring> +#include <fstream> + +// command-line parameters +struct parakeet_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + bool use_gpu = true; + int32_t gpu_device = 0; + + bool print_segments = false; + bool output_txt = false; + bool no_prints = false; + + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::string output_file = ""; + std::vector<std::string> fname_inp = {}; +}; + +static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params); + +static char * requires_value_error(const std::string & arg) { + fprintf(stderr, "error: argument %s requires value\n", arg.c_str()); + exit(1); +} + +static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) { + if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-"){ + params.fname_inp.push_back(arg); + continue; + } + + if (arg[0] != '-') { + params.fname_inp.push_back(arg); + continue; + } + + if (arg == "-h" || arg == "--help") { + parakeet_print_usage(argc, argv, params); + exit(0); + } + #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } + else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + parakeet_print_usage(argc, argv, params); + exit(1); + } + } + + return true; +} + +static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]); + fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", ""); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device); + fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, "\n"); +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + bool * is_first = (bool *) user_data; + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf)); + printf("%s", text_buf); + fflush(stdout); + + *is_first = false; +} + +static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + parakeet_params params; + + if (parakeet_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.no_prints) { + parakeet_log_set(cb_log_disable, NULL); + } + + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + parakeet_print_usage(argc, argv, params); + return 1; + } + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + ctx_params.use_gpu = params.use_gpu; + ctx_params.gpu_device = params.gpu_device; + + if (!params.no_prints) { + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + } + + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); + return 1; + } + + if (!params.no_prints) { + fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); + } + + // Process each input file + for (const auto & fname : params.fname_inp) { + if (!params.no_prints) { + fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); + } + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) { + fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str()); + continue; + } + + if (pcmf32.empty()) { + fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str()); + continue; + } + + bool is_first = true; + struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + full_params.n_threads = params.n_threads; + full_params.new_token_callback = token_callback; + full_params.new_token_callback_user_data = &is_first; + + const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH); + int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size()); + + if (ret != 0) { + fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); + continue; + } + + printf("\n"); + + if (params.output_txt) { + const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt"; + + std::ofstream fout(fname_out); + if (fout.is_open()) { + const int n_segments = parakeet_full_n_segments(pctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = parakeet_full_get_segment_text(pctx, i); + fout << text << "\n"; + } + fout.close(); + if (!params.no_prints) { + fprintf(stderr, "Output written to: %s\n", fname_out.c_str()); + } + } else { + fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str()); + } + } + + if (!params.no_prints) { + parakeet_print_timings(pctx); + } + + if (params.print_segments) { + const int n_segments = parakeet_full_n_segments(pctx); + fprintf(stderr, "\nSegments (%d):\n", n_segments); + + for (int i = 0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text(pctx, i); + const int64_t t0 = parakeet_full_get_segment_t0(pctx, i); + const int64_t t1 = parakeet_full_get_segment_t1(pctx, i); + const int n_tokens = parakeet_full_n_tokens(pctx, i); + + fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + fprintf(stderr, "Tokens [%d]:\n", n_tokens); + + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j); + const char * token_str = parakeet_token_to_str(pctx, token_data.id); + + fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start ? "true": "false", + token_str); + } + } + } + } + + parakeet_free(pctx); + + return 0; +} diff --git a/examples/parakeet-quantize/CMakeLists.txt b/examples/parakeet-quantize/CMakeLists.txt new file mode 100644 index 00000000000..6b46da18d27 --- /dev/null +++ b/examples/parakeet-quantize/CMakeLists.txt @@ -0,0 +1,7 @@ +set(TARGET parakeet-quantize) +add_executable(${TARGET} parakeet-quantize.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${CMAKE_THREAD_LIBS_INIT}) +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-quantize/parakeet-quantize.cpp b/examples/parakeet-quantize/parakeet-quantize.cpp new file mode 100644 index 00000000000..a5d9616420f --- /dev/null +++ b/examples/parakeet-quantize/parakeet-quantize.cpp @@ -0,0 +1,230 @@ +#include "ggml.h" +#include "ggml-backend.h" + +#include "common-ggml.h" + +#include <cassert> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <string> +#include <vector> + +struct parakeet_hparams { + int32_t n_vocab = 0; + int32_t n_audio_ctx = 0; + int32_t n_audio_state = 0; + int32_t n_audio_head = 0; + int32_t n_audio_layer = 0; + int32_t n_mels = 0; + int32_t ftype = 0; + int32_t n_fft = 0; + int32_t subsampling_factor = 0; + int32_t n_subsampling_channels = 0; + int32_t n_conv_kernel = 0; + int32_t n_pred_dim = 0; + int32_t n_pred_layers = 0; + int32_t n_tdt_durations = 0; + int32_t n_max_tokens = 0; +}; + +static bool parakeet_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // magic + { + uint32_t magic; + finp.read((char *) &magic, sizeof(magic)); + if (magic != GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__); + return false; + } + fout.write((char *) &magic, sizeof(magic)); + } + + // hparams + parakeet_hparams hparams; + { + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + finp.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + finp.read((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + finp.read((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + finp.read((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + finp.read((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + finp.read((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + finp.read((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + finp.read((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + finp.read((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + + const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR; + const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype; + + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); + fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype); + fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src); + fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst); + fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION); + + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + fout.write((char *) &ftype_dst, sizeof(ftype_dst)); + fout.write((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + fout.write((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + fout.write((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + fout.write((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + fout.write((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + fout.write((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + fout.write((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + fout.write((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + } + + // mel filterbank + { + int32_t n_mel, n_fb; + finp.read((char *) &n_mel, sizeof(n_mel)); + fout.write((char *) &n_mel, sizeof(n_mel)); + finp.read((char *) &n_fb, sizeof(n_fb)); + fout.write((char *) &n_fb, sizeof(n_fb)); + + const size_t n = (size_t) n_mel * n_fb; + std::vector<float> buf(n); + finp.read((char *) buf.data(), n * sizeof(float)); + fout.write((char *) buf.data(), n * sizeof(float)); + } + + // window function + { + int32_t n_window; + finp.read((char *) &n_window, sizeof(n_window)); + fout.write((char *) &n_window, sizeof(n_window)); + + std::vector<float> buf(n_window); + finp.read((char *) buf.data(), n_window * sizeof(float)); + fout.write((char *) buf.data(), n_window * sizeof(float)); + } + + // TDT durations + { + std::vector<uint32_t> buf(hparams.n_tdt_durations); + finp.read((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + fout.write((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + } + + // vocab + { + int32_t n_tokens; + finp.read((char *) &n_tokens, sizeof(n_tokens)); + fout.write((char *) &n_tokens, sizeof(n_tokens)); + + for (int i = 0; i < n_tokens; ++i) { + int32_t len; + finp.read((char *) &len, sizeof(len)); + fout.write((char *) &len, sizeof(len)); + + std::string token(len, '\0'); + finp.read(&token[0], len); + fout.write(&token[0], len); + } + } + + // tensors — quantize 2D weights skipping tensors that must stay F32: + // ggml_ssm_conv / ggml_conv2d_dw CUDA kernels require F32 weights. + // pos_bias_u / pos_bias_v are declared F32 in the loader. + const std::vector<std::string> to_quant = { ".*" }; + std::vector<std::string> to_skip = { + // CUDA kernel constraints (ggml_ssm_conv / ggml_conv2d_dw require F32 weights) + "encoder\\.layers\\..+\\.conv\\.depthwise_conv\\.weight", + // Declared F32 in loader (pos_bias tensors) + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_u", + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_v", + }; + + // Prediction/joint tensors use n_pred_dim as their inner dimension. K-quant + // types (block size 256) cannot quantize 640 evenly, so keep them F32. For + // other types (Q8_0, Q4_0, block size 32) 640 is divisible and they can be + // quantized normally. The loader mirrors this logic at load time. + { + const ggml_type qtype = ggml_ftype_to_ggml_type(ftype); + const int32_t blck = ggml_blck_size(qtype); + if (blck > 1 && hparams.n_pred_dim % blck != 0) { + to_skip.push_back("decoder\\.prediction\\.embed\\.weight"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_ih_l.*"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_hh_l.*"); + to_skip.push_back("joint\\.pred\\.weight"); + to_skip.push_back("joint\\.joint_net\\.2\\.weight"); + } + } + + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, to_skip)) { + fprintf(stderr, "%s: failed to quantize tensors\n", __func__); + return false; + } + + finp.close(); + fout.close(); + + return true; +} + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + // initialise F16 lookup tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); + + if (ftype == GGML_FTYPE_UNKNOWN) { + fprintf(stderr, "%s: invalid quantization type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + const int64_t t_start_us = ggml_time_us(); + + if (!parakeet_model_quantize(fname_inp, fname_out, ftype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", argv[0], fname_inp.c_str()); + return 1; + } + + printf("\n%s: quantize time = %8.2f ms\n", argv[0], (ggml_time_us() - t_start_us) / 1000.0f); + printf("%s: output model = %s\n", argv[0], fname_out.c_str()); + + return 0; +} diff --git a/include/parakeet.h b/include/parakeet.h new file mode 100644 index 00000000000..d35aa870adb --- /dev/null +++ b/include/parakeet.h @@ -0,0 +1,342 @@ +#ifndef PARAKEET_H +#define PARAKEET_H + +#include "ggml.h" +#include "ggml-cpu.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> + +#ifdef __GNUC__ +# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define PARAKEET_DEPRECATED(func, hint) func +#endif + +#ifdef PARAKEET_SHARED +# ifdef _WIN32 +# ifdef PARAKEET_BUILD +# define PARAKEET_API __declspec(dllexport) +# else +# define PARAKEET_API __declspec(dllimport) +# endif +# else +# define PARAKEET_API __attribute__ ((visibility ("default"))) +# endif +#else +# define PARAKEET_API +#endif + +#define PARAKEET_SAMPLE_RATE 16000 +#define PARAKEET_HOP_LENGTH 160 + +#ifdef __cplusplus +extern "C" { +#endif + + struct parakeet_context; + struct parakeet_state; + struct parakeet_full_params; + + typedef int32_t parakeet_pos; + typedef int32_t parakeet_token; + typedef int32_t parakeet_seq_id; + + struct parakeet_context_params { + bool use_gpu; + int gpu_device; // CUDA device + }; + + typedef struct parakeet_token_data { + parakeet_token id; // the BPE subword ID (0-8191) + + int duration_idx; // index into the models durations array + int duration_value; // actual duration value + int frame_index; + + float p; + float plog; + + int64_t t0; + int64_t t1; + + bool is_word_start; + } parakeet_token_data; + + typedef struct parakeet_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } parakeet_model_loader; + + PARAKEET_API const char * parakeet_version(void); + + // Various functions for loading a ggml parakeet model. + // Allocate (almost) all memory needed for the model. + // Return NULL on failure + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523) + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx); + + // Frees all allocated memory + PARAKEET_API void parakeet_free (struct parakeet_context * ctx); + PARAKEET_API void parakeet_free_state(struct parakeet_state * state); + PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params); + PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided parakeet context. + // Returns 0 on success + PARAKEET_API int parakeet_pcm_to_mel( + struct parakeet_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + PARAKEET_API int parakeet_pcm_to_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context. + // Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 128 + // Returns 0 on success + PARAKEET_API int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel); + + PARAKEET_API int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context. + // Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + PARAKEET_API int parakeet_encode( + struct parakeet_context * ctx, + int offset, + int n_threads); + + PARAKEET_API int parakeet_encode_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + int offset, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + PARAKEET_API int parakeet_tokenize( + struct parakeet_context * ctx, + const char * text, + parakeet_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0) + int parakeet_token_count(struct parakeet_context * ctx, const char * text); + + PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length + PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length + PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx); + + PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx); + + // Token logits obtained from the last call to parakeet_full/parakeet_chunk + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx); + PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state); + + // Token Id -> String. Uses the vocabulary in the provided context + PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token); + + PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len); + + // Special tokens + PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx); + + // Performance information from the default state. + struct parakeet_timings { + float sample_ms; + float encode_ms; + float decode_ms; + }; + PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx); + + // Print system information + PARAKEET_API const char * parakeet_print_system_info(void); + + // Available sampling strategies + enum parakeet_sampling_strategy { + PARAKEET_SAMPLING_GREEDY, + }; + + // Token callback. + // Called for each new predicted token. + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_token_callback)( + struct parakeet_context * ctx, + struct parakeet_state * state, + const parakeet_token_data * token_data, + void * user_data); + + // Text segment callback + // Called on every newly generated text segment + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data); + + // Progress callback + typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data); + + // Parameters for the parakeet_full() function + // If you change the order or add new parameters, make sure to update the default values in parakeet.cpp: + // parakeet_full_default_params() + struct parakeet_full_params { + enum parakeet_sampling_strategy strategy; + + int n_threads; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool no_context; // do not use past transcription (if any) as context + + int audio_ctx; // overwrite the audio context size (0 = use default) + + // called for every newly generated text segment + parakeet_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called for every newly generated token + parakeet_new_token_callback new_token_callback; + void * new_token_callback_user_data; + + // called on each progress update + parakeet_progress_callback progress_callback; + void * progress_callback_user_data; + + // called each time before the encoder starts + parakeet_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + + // called each time before ggml computation starts + ggml_abort_callback abort_callback; + void * abort_callback_user_data; + }; + + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params() + PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void); + PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void); + + PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy); + PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context + PARAKEET_API int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + PARAKEET_API int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Process a single chunk of audio data that fits within the model's audio context window. + // This is more efficient than parakeet_full() for short audio clips. + PARAKEET_API int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Number of generated text segments + PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx); + PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state); + + // Get the start and end time of the specified segment + PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment); + + PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment); + + // Get the text of the specified segment + PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment); + PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment); + + // Get number of tokens in the specified segment + PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment); + + // Get the token text of the specified token in the specified segment + PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token); + + // Get the token id of the specified token in the specified segment + PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment + PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment + PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Control logging output; default behavior is to print to stderr + + PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py new file mode 100755 index 00000000000..2d6a6d01554 --- /dev/null +++ b/models/convert-parakeet-to-ggml.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Convert Parakeet TDT model from NeMo format to ggml format +# +# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32] +# +# The NeMo file is a tar archive containing: +# - model_weights.ckpt (PyTorch checkpoint) +# - model_config.yaml (model configuration) +# - tokenizer files +# +# This script extracts the NeMo archive, loads the model weights and configuration, +# and saves them in ggml format compatible with whisper.cpp. +# + +import torch +import argparse +import io +import os +import sys +import struct +import tarfile +import tempfile +import shutil +import yaml +import numpy as np +from pathlib import Path +from typing import Optional + +def hz_to_mel(freq): + return 2595.0 * np.log10(1.0 + freq / 700.0) + +def mel_to_hz(mel): + return 700.0 * (10.0**(mel / 2595.0) - 1.0) + +def extract_nemo_archive(nemo_path, extract_dir): + print(f"Extracting {nemo_path} to {extract_dir}") + with tarfile.open(nemo_path, 'r') as tar: + tar.extractall(path=extract_dir) + print("Extraction complete") + +def load_model_config(config_path): + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + return config + +def load_tokenizer(extract_dir, config): + tokenizer_model_path = None + tokenizer_vocab_path = None + + for file in os.listdir(extract_dir): + if file.endswith('_tokenizer.model'): + tokenizer_model_path = os.path.join(extract_dir, file) + elif file.endswith('tokenizer.vocab'): + tokenizer_vocab_path = os.path.join(extract_dir, file) + + if not tokenizer_model_path: + raise FileNotFoundError("Tokenizer model file not found") + + if not tokenizer_vocab_path: + raise FileNotFoundError("Tokenizer vocab file not found") + + tokens = {} + with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + parts = line.strip().split('\t') + if len(parts) >= 1: + token = parts[0] + tokens[token.encode('utf-8')] = idx + + print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}") + + if len(tokens) != 8192: + print(f"WARNING: Expected 8192 tokens, got {len(tokens)}") + + return tokens + +def write_tensor(fout, name, data, use_f16=True, force_f32=False): + if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1: + data = data.reshape(1, -1, 1, 1) + print(f" Reshaped conv bias {name} to {data.shape}") + + n_dims = len(data.shape) + + ftype = 1 if use_f16 and not force_f32 else 0 + if force_f32: + data = data.astype(np.float32) + elif use_f16: + if n_dims < 2 or 'bias' in name or 'norm' in name or \ + ('pre_encode.conv' in name and n_dims == 4) or \ + 'depthwise_conv.weight' in name: + data = data.astype(np.float32) + ftype = 0 + else: + data = data.astype(np.float16) + else: + data = data.astype(np.float32) + + dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)] + print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}") + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + + data.tofile(fout) + +def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None): + nemo_path = Path(nemo_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Create temporary directory for extraction + with tempfile.TemporaryDirectory() as temp_dir: + extract_nemo_archive(nemo_path, temp_dir) + + config_path = os.path.join(temp_dir, 'model_config.yaml') + config = load_model_config(config_path) + + print("Model configuration:") + print(f" Sample rate: {config['sample_rate']}") + print(f" Encoder layers: {config['encoder']['n_layers']}") + print(f" Encoder d_model: {config['encoder']['d_model']}") + print(f" Mel features: {config['preprocessor']['features']}") + + weights_path = os.path.join(temp_dir, 'model_weights.ckpt') + print(f"\nLoading model weights from {weights_path}") + checkpoint = torch.load(weights_path, map_location='cpu') + + # Extract state dict + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + print(f"Loaded {len(state_dict)} tensors") + + # Load tokenizer + print("\nLoading tokenizer...") + tokens = load_tokenizer(temp_dir, config) + print(f"Loaded {len(tokens)} tokens") + + # Prepare hyperparameters for the Parakeet ggml format. + hparams = { + 'n_audio_ctx': 5000, + 'n_audio_state': config['encoder']['d_model'], + 'n_audio_head': config['encoder']['n_heads'], + 'n_audio_layer': config['encoder']['n_layers'], + 'n_mels': config['preprocessor']['features'], + 'n_fft': config['preprocessor']['n_fft'], + 'subsampling_factor': config['encoder']['subsampling_factor'], + 'n_subsampling_channels': config['encoder']['subsampling_conv_channels'], + 'n_conv_kernel': config['encoder']['conv_kernel_size'], + + 'n_pred_dim': config['decoder']['prednet']['pred_hidden'], + 'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'], + 'n_vocab': config['decoder']['vocab_size'], + 'n_tdt_durations': config['model_defaults']['num_tdt_durations'], + 'n_max_tokens': config['decoding']['greedy']['max_symbols'], + } + + print("\nGGML hyperparameters:") + for key, value in hparams.items(): + print(f" {key}: {value}") + + # Create output file + if out_name: + fname_out = output_dir / out_name + else: + fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin") + print(f"\nWriting to {fname_out}") + + with open(fname_out, 'wb') as fout: + # Write magic number + fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex + + # Write hyperparameters + fout.write(struct.pack("i", hparams['n_vocab'])) + fout.write(struct.pack("i", hparams['n_audio_ctx'])) + fout.write(struct.pack("i", hparams['n_audio_state'])) + fout.write(struct.pack("i", hparams['n_audio_head'])) + fout.write(struct.pack("i", hparams['n_audio_layer'])) + fout.write(struct.pack("i", hparams['n_mels'])) + fout.write(struct.pack("i", 1 if use_f16 else 0)) + fout.write(struct.pack("i", hparams['n_fft'])) + fout.write(struct.pack("i", hparams['subsampling_factor'])) + fout.write(struct.pack("i", hparams['n_subsampling_channels'])) + fout.write(struct.pack("i", hparams['n_conv_kernel'])) + fout.write(struct.pack("i", hparams['n_pred_dim'])) + fout.write(struct.pack("i", hparams['n_pred_layers'])) + fout.write(struct.pack("i", hparams['n_tdt_durations'])) + fout.write(struct.pack("i", hparams['n_max_tokens'])) + + # Extract mel filterbank from model + fb_key = None + for key in state_dict.keys(): + if 'featurizer.fb' in key or 'filterbank' in key.lower(): + fb_key = key + break + + if not fb_key: + print("\nERROR: Mel filterbank not found in model!") + print("Expected tensor with 'featurizer.fb' or 'filterbank' in name") + print("\nAvailable preprocessor tensors:") + for key in sorted(state_dict.keys()): + if 'preprocessor' in key or 'featurizer' in key: + print(f" {key}: {state_dict[key].shape}") + raise ValueError("Mel filterbank tensor not found in model") + + print(f"\nUsing model's mel filterbank from: {fb_key}") + mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32) + print(f" Filterbank shape: {mel_filters.shape}") + print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}") + print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}") + print(f" First row sum: {mel_filters[0].sum():.6f}") + + if len(mel_filters.shape) != 2: + raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}") + + n_mels, n_freqs = mel_filters.shape + fout.write(struct.pack("i", n_mels)) # n_mel + fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins) + + # Write mel filterbank + for i in range(n_mels): + for j in range(n_freqs): + fout.write(struct.pack("f", mel_filters[i, j])) + + # Extract window function from model + window_key = None + for key in state_dict.keys(): + if 'featurizer.window' in key or 'preproc' in key and 'window' in key: + window_key = key + break + + if not window_key: + print("\nERROR: Window function not found in model!") + print("Expected tensor with 'featurizer.window' in name") + raise ValueError("Window function tensor not found in model") + + print(f"\nUsing model's window function from: {window_key}") + window = state_dict[window_key].squeeze().numpy().astype(np.float32) + print(f" Window shape: {window.shape}") + print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}") + print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}") + print(f" Window sum: {window.sum():.6f}") + + if len(window.shape) != 1: + raise ValueError(f"Expected 1D window, got shape {window.shape}") + + n_window = window.shape[0] + fout.write(struct.pack("i", n_window)) + + # Write window function + for i in range(n_window): + fout.write(struct.pack("f", window[i])) + + # Write TDT durations + tdt_durations = config['model_defaults']['tdt_durations'] + if len(tdt_durations) != hparams['n_tdt_durations']: + raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}") + + for duration in tdt_durations: + fout.write(struct.pack("I", duration)) + + fout.write(struct.pack("i", len(tokens))) + for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]): + fout.write(struct.pack("i", len(token_bytes))) + fout.write(token_bytes) + + # Pre-collect prediction LSTM input-hidden biases so they can be + # folded into the hidden-hidden bias during the main write loop. + lstm_prefix = 'decoder.prediction.dec_rnn.lstm' + pred_bias_ih = {} + for key, t in state_dict.items(): + if f'{lstm_prefix}.bias_ih_l' in key: + layer_idx = int(key.rsplit('bias_ih_l', 1)[1]) + pred_bias_ih[layer_idx] = t.squeeze().numpy().astype(np.float32) + + print("\nConverting model weights...") + for name, tensor in state_dict.items(): + # Skip the filterbank and window - already written in preprocessing section + if name == fb_key: + continue + if name == window_key: + continue + + # bias_ih is folded into bias_hh below; skip writing it separately + if f'{lstm_prefix}.bias_ih_l' in name: + continue + + # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions + if 'conv' in name and 'weight' in name and len(tensor.shape) == 4: + data = tensor.numpy() + else: + data = tensor.squeeze().numpy() + + # For prediction LSTM weights/biases: + # Fold bias_ih into bias_hh (bias_ih already skipped above). + # Reorder gates (input, forget, cell, output) from PyTorch layout + # [i, f, g, o] to [i, f, o, g] so the three sigmoid-gated outputs + # (i, f, o) are contiguous. + if name.startswith(f'{lstm_prefix}.'): + if f'{lstm_prefix}.bias_hh_l' in name: + layer_idx = int(name.rsplit('bias_hh_l', 1)[1]) + data = data.astype(np.float32) + pred_bias_ih[layer_idx] + name = name.replace('bias_hh_l', 'bias_h_l') + h = data.shape[0] // 4 + data = np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + write_tensor(fout, name, data, use_f16=use_f16) + + print(f"\nConversion complete!") + print(f"Output file: {fname_out}") + print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Convert Parakeet TDT model from NeMo format to ggml format' + ) + parser.add_argument('--model', type=str, required=True, + help='Path to Parakeet .nemo model file') + parser.add_argument('--out-dir', type=str, required=True, + help='Directory to write ggml model file') + parser.add_argument('--use-f32', action='store_true', default=False, + help='Use f32 instead of f16 (default: f16)') + parser.add_argument('--out-name', type=str, default=None, + help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)') + + args = parser.parse_args() + + if not os.path.exists(args.model): + print(f"Error: {args.model} not found") + sys.exit(1) + + use_f16 = not args.use_f32 + convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name) diff --git a/models/for-tests-ggml-parakeet-tdt.bin b/models/for-tests-ggml-parakeet-tdt.bin new file mode 100644 index 0000000000000000000000000000000000000000..8b1dda1feba08f4439a0e5fd3e2edbb7c38cc41e GIT binary patch literal 16603 zcmbVz2|U$Zx4)qb2}z_vC_+&x%JAFkq>|D+P*Reph(eM^V=~VoN}^;|M26p9M-vSa z4NB5LsYE1Ess879pXc7^{om)k_r7<3KKpkL&RO>UuC>?xuC>?R<+6SI2mt{BpOFFr zV!T%&UKix`QM}G?9A4-Dei*L{|M8yxcmC&w^E&@?s=WS6IEBno4#XHG4m_>d1A!Ia zus-iRCcYCQ;UCgy{Mq12c{^(~Qn3SDrA08_IEZmu90gl9eIh5mZ^ZVym1Lc_Cn~?* z0!l}hvXM&}>bOc5O%A2NcN1lLYuPjUbV)IFS||pGCRC8ZRS!r^Tr;T}CC*8cyHBmu zgmLGZmH4nw3AFv9VDYB+?9nk+I7?;?a>6y>-9iGIBPHO;W^F90a3KQ&T}=BYThb?I z3&vl!kwvvS@WwHO@f3|C6H-(`E}$Gg=cLnuhuzdCd=k392&0TOf@0<-9Q(=(7VT(6 z$$dGLy`VvIM$5zJ(FSl&!2%S*F3~MzAL)vWjbQ#}J}q3FjuFz)Xk_P%O`*St<P#Zk z{@M&^<~C8|^=jmzh6?0tJx$$P-jnYhS~&A{8P)K(M$KfB&~xoK;=8V!sueFaahd#; zRHn9(d`oS(dg&G$QT~K3o1f0ces9DdMrCNYJPV^Oj-r0SU2>pyJ03e-MU|GF2jd+j zw4v4nq6K%O;pWx!+-wh|@p*K&vK2g5c#hS(TFBQu9>ie00`5{M!}Ke$%;PXo%r$i- z3x$hJ#){TL#yC&fFX)eAadXIhGdt4!Ef1P*t%G}|CB$!81=wpP;?<jzP;RyhPKmFk zoH>O!FG&qHMA%ci`IUI|moK&koFV%nE#Xz%Zzwxt#PQa7NNiUhMt2V}?g5!#@=kFj zUe0<*w_H-hLsLi55$?yqs$PXmb5o<sq5@FZ<}q6T9H0soHf)f^YbbCWPU_<hLT>tR z%H*EFHW@E;x?qS;=ZbM0%~qpPrX`qYY{!^`u3({XgA_QZ(B!Wg?9?+;FxI|}B=)|e z=U=9vXKXsQmd_yB6K_M~!$~A_;4^#ADu&2R&4hdBwh$?Y2HJ8m9Qp&Fk<8fZP?u&* zwpXNJ<HQ0?UiARRq&36q2Xm;Z-gTNhsseJ0!tvoq2b8%u84AR2({+zBD`GcIp)rqh z$!5)BhzlrVraheu!wrPFpTich-OsI1Mkoan?`=V&(K+<QekWk6v+zyVcN3{M{q)B5 zOuE_H2R>DaL-X_;U=5V%QmvT?#twAMs{Q0#-E3l18-kZtmobIKJBfv|D#zWV8#G<5 z@%^M&Dzmx*W9N>=9|>t>=Z2HOoHN47`=;P@7jf>JXbn2^XArIr7R6rgEm%IXiX}f^ z;-s1iOjsz!nSVP2_uGfz;KU9{{b<N}(-)7bp_SD3Wj^U&_#Cp&9ztd9gEaBC8Du53 zkf0M>+A1ba69x;gzSV>F*dM2#Sr^i&C51;_dPwdWUue;sNSwwMlU-#Sn3;oxb49Pr z#bAlEWG1O*_kSJExnzF}^22YCD^~N-$36mWy%Iqt=os1fsE;Hre}O5xT%acF9Z6J+ zz#``!+A}<exU^d``%c8u3q225U149k!BCMr-B1Z?3np`49(Sj?IGY>`J3&r3Q)=4$ z5=A~`lKbPHL(IDC_-N85@SJ^tGSd&Tb~lb;Ba=g$x1WbMnsQ)Y`5bm%PJn^Bmjn!7 zK*_b2M67=eXgoPVr!79tYz~=;w(AZ<W<vrhnVW-{(R}LC5f8m{Iq+%3dypJs45#Ma zLR1cB*9}X+pLxUZ)9xy8(w_tqI%ndMDqjpR6Jf}eWPDmT6)$5dS@0?r?lew-!x^Kw zeUFlf)v;K{<^2QF_u?Lm8##;{v3e|~FB}U6ZnyDmM<p;xm886EJ#5k%jrFTnLf*CG z5bNbiZj0<8u1+Cj=CDn8Rc0)0(N(0gmR=_ZzZ?a{BSE-XNPt?NVd3sZg5Qs<#*JDl zNSfAqu%2yfVmBchnk+X$psy9gHA%w4fiawSa;s^Hp#WEZeJ#XUeW1I#o0*NNnb0U- zi{JInVBXie7_ig=B^6!~9X%1K^9U#YkBYJKi7fYd+GZv#tcRYO-9b!>W`S0uB#}!r z1^W^wvig!J?01-rQXl8g+@p)h6sgtVmUxTKbsi2PipObhSvPC5aV?xw{EhcMpF_?S z7W+5IaoMvv__hgPO>;Co6ug8SbPk2oo-EKy3}=^_=0SUIB~xK*iK*HOIM<IwYGaT6 z%StJhE`u4S&9J7U5DQdpk<b}|cvPa7#2Jgwi5oxQiwqZ7`)WKE=856?&f}<fbP9To zc}5oAnn&%;E|7+0WyDP)3vzr5Xr4HWc=9;ZU&*5V*%2VP<T~jv7ok-#mq<<dH{f&* z<6IgW!3c^^Cd}1fAcxoEtmD~aP0NGIxGQJK1*>(~yYCKO%NfRX6N$vbjUTAKQUcZM z?+3l;a5!nP51)u!B9Xi9(@BZZROlu{1pJg?Nl_kMJncA$O^w0#rFF)OIU1Z{ySI^8 zk5cGztzhc%#=z0|9qgK;x>PG{5$;>rOssxwz&rNZQ1BrH1hXs1%*Yt3<NA&`j5tk) z`Rs-O=MPx%Fo$+F{K6@l5l6)4!xL=>cs6}5im&EkV>!b;rY3=7-i_x>$!a009BKTe z+(i!6R^tt$MEa!s4imDjkX`7$1gty@iPPArplvt<jT;?6$oD>6*|;A?IQIDCdp6{M znopezi(o|DYx;?l!`uuJ!r^g<XoCQfZ(ajG49lop=N{D0enVF!wv#(fHE5wX2|7NC z<FSqyh%z3>-D0yCXS6ipP5(8hnsS&-51K=Z%YCVh;3O)va3v`j>x<*Q%!UIQT5xgF z4f1`z6pX)qf@BaOj_c-2tgz83I=!8P(^UhoX-O&?9=FGr+$gp}Gly6nSj9~IJ|B;| zJA%d%2A&?6ME%vL;7x~-pmB&pOl-=~#4wbN)H+5SS3BZzM^kcUPC2TJ<-mi=RQR&k zmln#G6D^bR+*!hLsOOYGFBk`7;cEkM%KQy0lM_&2n<d(fb0boQ_u$&T6kH>Gl<*)J z;$D0{#^+Whp03V>Pr6Jz2%mqMO!pvs?q$b+Y~?@f`HyY<hXeoN$bW3-Q!^!A=Qloe z<97#F``z0d+;#W3JJ|l?z3y($y}DiwJGbxHyISQvJlXXCWxuCm|LZ5zeBp3R7g-7C zJiSPtUMqI@2*Zh|vM8^b4sls3IQN?xo(viR8Y||5xyu`pS=ESVD#by6TL2n<+(Wud zO3*T(jVw649Bp+nu%prdvfR8#yp9qYbIV}nsaMp}Uziao2qrFPm9ctm4_Ox_hNf2= zsH}57eAElXUryt>(`Vg-8Jd$gdym%96(9Pss{b^t4~?M?2TiGue*`LT9s|l|B2ZwC z82M`^=-CQ$Y~^Q=*OznYlASVizpoOUKlU8OEj3BohiZB}@FQ4S&!kp|AHkMI;dt6o z9$b&74`Diij~1`}H(+huY3K3Edo!-h)F<_`l0mfQ95s%4f*mg+D`ySgjw=@xP`;W} z_;PrSuQ`6_D~132ztMsHZr6RfGyYJ+;jx+Q{nbs_R(lV<>c5h<`@`|Z<jLHS4s#f} zH<b8r4X`cz0{Zo@1NoLqB;t?)=c`&H5%s-EG(TxmLxoikKU@Wl%{>Fop2M-cvxN#T z>!Uy1mO$lggvp&BXwSi?wBp_w(thC}y<O%3%1WcSJ!jY8o~7nk>NAEbba5Z%E^<P< zu4O}WQRa=C-}rO*a~}NHe?1pIN;#eppxbkR7Weq0_TruF{{@wqe;TL7G1&9O8~04u zi84Q>U{ljp8d;?dcM@zXjnB*ix5EKg^lM?I3R=;XXOnTHN)3<jPY1EWA}ZJ9hx@)= zXFqO@f|Pz4?z3pb#iqOPy=VoHf+*7Nx&ZEUN8^{Jk=VHC4UxI%%4i%I2@j@m=(8OW zWZ=Cjxi(=DC?qb2eyci!2hv>rT>f2DX6o`$(KGacEux8VGu;p6rMvL2sQeY*|ISus z{ejAjhTDvjg)793eG6lz4TH>8l<H?}!&#?|O$-f_D0cedePv;eqfZ-1I7LHKr89=T zwuT5g39RB@54n?ncNVjTW+B#>3I&;J5bMUkl97|(ue12yTD<-rvoJW6Pi&X9VbHuB zCTXh_KCDm2>kAv<%8KXE^F0sN<@k}suX3UIz(SL;*(=D4uc7E}dWV=;c`?h`^Ej$_ z6+3TnI(ta|EogUagUmo<lkr!NkP!;{OhL^G`YkL3*XNC=O9cR}Mx4RjFNC0RO$~W; zsSV{!f3S<+9)WXlL6ExKANJZ$Wd1L@)gMBoZS8tIR}z9!dkV>~8CQq!RpD6*zx{7@ zYw!mu3#v0f?QR(KI*QUsrm5ItmrkF#Z6@!#meZDQSyVpYOST9LVt0)S;=WY!t-hA} zeDx-duR=-PljtG4`ge6}Ff<Fs;VYFDzXHQC4i|Wda`_|I9-6~HJH}mh-VW{_Jjdm_ z+ug<1(ZOMxt>bQYFFW^bdVkEJXJr=-dUcr?oBcHT`n;QsbxMXS(o?xlN^cYW=heWM ztod(B=C7fN8>3KQ-&J~y<>2$)ZFpZIlRhpjfy2`~z(^pM^7lTSzsObLz5ml!{l}>P zdBD4Nx;ogo|2c5sr@fGIMTY6!{}MLIA}camgzTSKL9E;@I8B`!V2QLo!l?bAWLidp zx!>uylRIE`<}VZ9v<PCR=3sJWZ7xxbe@c#b^`l>gHk@%dh^D8mU>avPYy2yRNe$LR zU%D0??&e^Z^?kasTn;T72)nF0f}{)i;t3~1uov}ZET(=&m7?$1{CIJt>B6nFv-Sb& zW&Ea6>zgCdPI8BWxe=%>oPgq=L$UQ^EZJk73df!_)20b3+&zj-_-5WZTwN`ND;4S~ zBdW~IIh9JUtvm+%rFHNH*9LW__mPgs>zEr@O}2|RFxo%l>Di@s;HUXS?0hE7@o99V zW}fNzTQ8E?el-|JdvLIK(H)XzBn4Z|CSv@}Loi1x4l19NLVA!WJPODMLE4Syub-t- zUp8Ukv@gujC8N<?U7Jqm=mya{Iy9mA0TF)ULZhXVnb;{ej4MLkkh_}kaPN5o96NTJ zgnpEVpcS!LHc-UYmF$6ug`2>za6KvguoZ4D{6W&wli;;R92u)zLhXBVNOOcSH%a_D zNjkcd>v6If=PYk0<ww2{uiJ6BdZ_`nx)i|OhJGp(Ce4XY|52{|;Xa{#_u$)95mL9* z2X9?@$woEHu+Iweao3<9T_!-V<m@icy&=gY-;Sd1J)g4U4s0XA?Ml4O^C;5QfvD;j z4h>2VsNBZO=xp3d>PRO!&~**x{d6MJR+vD10g{kgsxTIdP~k`x`QCk!n$75^70sKW zWjzZWd2is1#RXVZ`I2Q%6cgX&3qXC^05h_EI63rcHP|$`vtgHv@V@>-+Pg>%6s9gP z7Mzh!R)$Q4xgAIkCe4Ej{Z(Y6)M<1-bp`5X2%vV;1N?O@5K}Jt;l-kLjG6&2SLK;Y zZ)Yfj`mlT=TAxP5AJr3q3MIJwPK29zqMSw@m*B=0$ucwUpQsRc7zo>kCBpoS6(oDC z0FKx`KvL92F)g^H+|}a{jK3{~eVpZ><D>(O%U)Dj<qhP~ObDd+nY1bMvGdDDoR=Yv z-t8lxDSa;{3%(>zdQRiU{*_c9@+BjDSdsI}q7c71Rl<@z<A_RN1KGCaGh<TJO44kV z=qlxMNGN$vM|DO4>6?rhxA(E54T7;{;2mlGb_F_ioo08WET%VdBJk23CtC3G4%%MY zi6@PA5SK?6K=Jl0j+?eTekd&k&Pi{Q{c=1Ve)u_=oN}0%GWIniAyZFH`!1n>{0)>g zo&slc!l`+Z0#4s3VUq7-0c$SKC3AYG;MD#U_^SiqJCDEe!T;8Q{K;3iYPaFI!A~M= z>J4E7r$Hez3X=7ez)WWx91t50QJ!I#xOEn+49&wAP8K-#*?GuOumTx-3hxv`Y1^S# zoT|T%-s-#0JieaBxG6@%o`Ea$Ia5q><PYOu<XcjI#lrY{FW|hxZZta4gx*{gLRVD0 z2I<qc*mm@$=_VrFn;*7e<i%rjMBH?|^Q;g&-e^+VkpTkq5HshH0`tD<G8=q*zwrUV zX7JrRhP!5*0GCKgptXprNkR4uI5hV&c~x<pitO<Nx99$F!P|tC@7V`hiJ`=rl}43} znV=T@nz?(bz=#G{!q{hu+`N64@ng>kT=uh!3flP5k=<NeHR}*g++u@!%#=6>Q{O_v z{RVnsN;y>?7mKHvFZgQnC)UX0C@SB~W5XWY#jB&{W3iqBd2qCXyrz#x&G0NNxN8l$ zO-u2+kQzK$rd#QyQ;6AZX|!)t1zFai2x_*Y=sKC}p!C%Trta{kgQ6S=JJt_cJ>Oyc z^iyc2=YkC~6skOfP;_w??&%t!eN!6ArpxuLWvB+uNScOa8x1gYiXEDXjON&`NhBFU z_ONNkck1FysCSM5p17>VP1#zA)>qrnb5NRlOfU&G60@L4$bg-xcMewUY9Ke)nZxun zM?9}80>i?MKr>YqM-XYq?->OL3r6D0l1L1QK7sXG7n#Q{ndIB+CFpbOsY$8SXr3~a zK`7%UDDF*%RoXsSb+i%2bcy0N8VE5ap$r-RfTnD1$K?F!^hHuNemvs^^EIx5jnGnj zpL~*-78&A}HWggF&<vI~DY3V5TR^A%EA@ML6bgK*8O19W@U*9m>i=58=7oG@8$~Tq z=ZF?u*{1=sddI?&itC^nPzKxqq@@B#EHWnI`88;w8$1rZHrK&9gK{D|-<BS!NrX)u zBe>G-0$jOD3%Hi&#T+&AAscu}%QI_frH`Qi2Y;=G@M;~BcjqPDJN!1SULeMGfA3G? z-PXgKyjysBwHDS+i9=IWFASI7$UJcz&V9&Hf^&l$2#oTfUDuSk7IH_iqV5!A-PYrt z=(<Q3j5vX-2bRIYhtqJ>gL`Dy)luLf-T}h*SJ2kuT(VT2qPJ}#TrM=lgm7sTnpaYp zII;#WnmUmeIbN{qM?Cqx<qnod{$!K;80fa^#!u0~sO7K)dKZU~?#>3%Dcl7bZ$FU@ z&YJM-kPavQlQoqvU5ct=&&Y0G@~`2s2qvfp;E{*29MSceM9Rw;cOGzpBSA0NJa1pB z)FVB_Y*cxUgWvwvc?@x;p}_AS(YK?`IgAJ`CdJRbKx2Lvd2yOcGwTIWPN9K1i26}J zRhl>??tc`>XGqPvtEhdf6Wo2u3w(2Jz|8R~UG|$xCeMq)_)*zI4sP0C#PJVs{tH7I z8s47WJ6-pB?euW4<^T96LF(U=3gwxeaHzEk7tLD>Qy15fAn#gSAyHkqAfuf0T1$ao zfIeBQ7*Cd_RziMP3OqbKhr2xJEM0U`lU`r&5=XXj8H>hLxG*LZsYx!`;o*cXLetP> z0t=(xj6mhgk3=@$6Ykp51+I(61M|v=4zrSl{=LOGb<+u0_v#!?;%zRA%YV?Ab7Sz+ zBRwz_@hAOHG~ndeQI(_YgkinMW7=6H#rb@(1|1yJ;KbMh7&FBhopYr)d!HPE=#>3% zE8>kw@Se?d!EQ@>rnj1mwl0F**@nb+?M}>;Tm!-Po#Fb4DKP9Q0fo{KYSQ3>o%1$h zvt>PAIO|Skwr8`W&0f)FO@=$}LMmSNiiEDo+UT8nh5Qh_gg#TQz|VuSRL3Wd#6BIv zDH(qirhZMKi=L0g$9d=J#p6Q2J=(~+ZMsbp?^eLex9{2g(_}eJV;QqDunl59mym>j zeW2)Li2ae)7`M5DJdimI0&SC-;0qb}Sur1X<d&e@yw})l_>$Qp7zJ&$8!OYZM{!Fm z=hK>E4f<rV2<NDOIld98p>`XZAj<wNS*$+_Am;;P({+J3%PfMUMk~<0W(!2o0(?60 z0(l`ihitu}g!><tRo>Y=5&d+6@z==9)MrIF@;At-f7lN{-uqh{AF`iq4tw_Q_z(N} zBQPjgc#a4d#G=Xq1GsE-0*ePmq0rljxZ`vdT%V9d@cw$Lx&Aj{nnk!zb7!!j`M=>w zpeT`+4`n8=9trkl4)8PgH+ow(;z>&@q9HT~-Mcd(WOEuVKHZP})iULe;r;8OZMEBL zzr&Wl68{*M`SP=P@%wacGTlvYv~#G=ny2)#LI{l8nFPxZ=ngGH^?ye!e<?Pc7UOD* zD8txkQV_G_G%+<j2=e3N(IZlTBYnzqXsmw?R{nk&{$;H0p02JtUANo1IM}($)&dCq zXvBx=uDDs@DrCh4(y7C(V0VT)S@z{I%Ln4G>->K$e4m~B`IvmTXz5OjKDB~++A?S# zF9}n9Zj(*%HQ?`RIOOAI|A7Yo|NM`#T|HfFhvwnnVY}De&feK!8~^KkTl(u5|Fy#Y z*%$rEwJvsy!AHs!v|FhVyP8#S&M_mL?Y|IW)mDJz3OVXlRzp{tOA{H1Rp@Xh1}!h` zL1lk6EcqCRJ)hOdxqC)1LW8#_G}W3sHL#;bP5DIS=bFk;>2mxQpN1vkJthhT{s5OH zxw66+N$|o%sK2C0M>mZG%fVT=VBZ&3;XoahiOX}xZE%A$%|rOv=soS(cM)5?i_u?b zCXKmagie7~M7gM)>@dq9*VpVNE56)<59g2LR$)8#$B!s*Yfb~}!OJMwnM&V3&I4~5 zArR|bL@M8HrV+l^sYi4y(VgQ<KRW8NHUnoEg#;Z;GZY)r!+&d)Lwk*f!!Ac#yS;yA zdk6iDI7|1o;y~LuQlD&!7w*e~WD#LLh|LE9uY@6?lla=<wZB!!KMUm>y2OZuWNdpB zd>51F_Qc)+buliv@iiD)C#Rrg-r6Cb@#i@CITpU&|BKE1IZRjGJ-a<@`C@H7pYD+( z60xC`o4cZjLwyLapI@Q=Qg3)#C~BgYyK*Sp|Gx;{=ew*D?^X){^%>9A_^A$Seo`#> zYR%k{`og*t#SiV0f6aXSEfD`X;(z22Y~B9wCU!jrRHjG|OA7_@LFZYxx<ni-+-^|e ziSE=isRqZsy-70Wy~0XQ4%l55CiT}8IHW!deK`%Rh3;GMJemk!jAr5$D+h8`RM<Gs zCW@3b&&1Q6CG3x|yI8wv1X0}imYjXLgZ?bLhc2TIgJ^m=QE4{B=h>0;lN=W_IQ}S@ zFOF|aPQlTVNhtAV5{#`A$No2m$>6O~G~LG*9NvpzviU8%+**h#T4Ff!RVuC!Z6Mm- zT{Pm`Sk7*j7Mv5fl3d>sh2v8P>4N*C$m@nN+%P|P;vLfqGFtCw@bEx#c7{K=v`(RQ zWn(b=oD7$*-+v!H=ReRJ7&{pnueY;jQhhP{tRRuyC5NBPMDeZu1W;O;gDz6Xz^?rg z)}GhG@sqxS*0Lx#*;dV*xVZ<n9TCONG6~Gri{qG(@#XCPxWP(|z>V~3%oQAIGz@De zJ3*X!0XB{|r4G+l;U$SSrtAAn%4Z(C3j|^6`%FxFG!+t7n4x%yEVkyJA**iwptEAO zfrXhFntLeYM^Q~IsvU(ZE``CJ0Tpn6(FAg;GMwNpe{7!<#Z-hR;B=WqD8Eb<n^p!P z$vB3KQI@*0$#^#23f}ZM;rF^(oaYnT>EPj5hOIq^bL^Vnu0<ql9G*!2Z|Lp&13l4g zxhP#`MI=Kb(NmXKoH1nugnC|M;yNdAHL4DxlFSUa3CXbNTR6y7-@#E1C&~6fU$9#6 z5PN(DD@W^xk@}xWG()x%+pEsgxVTJeab_&dhEaI9;Wd>?S7Ubuza|f!CsSXxo+?dd znb?<o*vot&bt$st`SEH@!Dmls?)!!(&JW`78DYe4_WO$Grn=-~@<%Ess}0Uot`Ky$ zpR5QCq+8oWxY-{UK(bv5)Cj5r%&^1Tho?cFg$*nTss`443zoEQr-sV_9%wi-Cx%60 zg84%-b<!AUQ!i#-G;JZ{Jp(ArW#L+^<B+fUANQ31vU30PFf7wr9^(gRnWzcdL1z}? zz0yoncM{^db_{3EO^(E>*6UPmr5320Ovagyc9DA1Bw8Rn4F>1;l4vzO99cAtZcPe= zi#u+S%(_dM@OTOYJ`6#l+G5C+^M@AaIC5T4m4uJ9q*$Z_>5o}zWavi5&as7v25GFL zN;vQ`5zldsW3$)?lzwak{vqw4>sA86lkTHzlp(~6mcZ}V?`iI}I%c)60QX$<aZ;{! z94ERgB}X4s<JEx?U_8km4jeJZ{A1B*J}A!JxcxY_Y8`|3f^MSh=XMgf^(`&E7>u3+ zlIW)r46@eaAzXd<kW2md@#DGdKkP5(urgdrj6?r~VNCWjd9pL2n>u`HA%|}-!W4fY zIBg~gc*%&{y)hAXj-1ckc3BCNi9hV{I7@%z@QPJ_hvIoRe`xOQ!x~;87kkN$-FYSv z#TLooj@f?riMt85`51$@WHX*gzeg*-*<fY#IT(3A7wxQSL1v2>yP!M{8|`$k&0rZ` z(n!aI*;g6MX{T^vs08=RwyX5lwYPXGoPz`9J8+noE!Id2aO$PEGa;r2$#A<?RCSty z7W-|;wMqVXZ{H61+LMIwr(@_rQF+XpuSqofCt=K`5GsGyjXoAI;T2X55IM!Y#OUDV zA>Q>@-}hhkxuNazuPNw1*^IbDGw+zw0*)VXqb{PSX!-WXOqJGCxD%CwR&oo6HfBwp zXXUrQ-H-l8GJhN0z2-AMJG7asOt}TuW_~ADS9Hl~Z)q&D{AFS%<@@jF%xCKUbI$yl z{$h-CileTREr{ig1$b!3a`fbdCW?|ZWcf)6aG2{toaDTqx3`G7aJh!WMJ4lU{VLe* zR#D6=mgFcqw$rcsv@zJAg^cR;AqMx0@MBRV$iL^ZKW9eZ80|zbo2CN(&1qPEpcOu! z*Mt<wDV0mY?m(aNByvFNZKbQY7HWJM1IvehAp-$>Ou|f-;JT5+N%rRmc5y*AZrE0d zZF!@>Do}$v$!axOR{e#Z*_#Q^GY4tk15H>m_YS@D<sGv+%>>3j{74OzkHX;ABlz-R z0#Tm13+4`Jg3R(hxVmRPehn%_3_F6?M{7aVr(dMUxt@_?BhhKq0#Lb@g8SCT!NQjY zoI|M<B-Tt0l-Cc^(U;q?-eNrdYULea%_+kr`<fxWIF#&kyh?QEo~Fvf_2`|C`=HZf z7#sOa6s#|s(`?-WdQGkvUu*nCw~jz6yK_79J~aaa^d;HzQ*9V=-m*?v=!9<q?vto# zH%V$vEb4RA2-D&R^SaDnrusGZ#K}+eilQpY9{WTrZ!IFJDyq;WE=mGUu7ZlPD#qrO zIgqZ+;Munf7r3>fPV^19xx=5B?R|zXD-AFrX(80=MN&_BVH~%z2z(`vVfC7&c*%1; zx%|GG<f;53Z_ReX%=H%Jc+__yTN*$Iw)^lR_GYHg)P_vGu@FR+rJz+&pLwOW0*z0E zFcE=oD!0`yBqehsIp5?&aDl-SvWZu>WOTWPnU^ZZ*&I<tMLwM)!AF*&<c}hH{M`|< zHFG^UZb>J3+G#ND@KyGgOg2@J(%}R=y+wL>H7W}$^w4{a8wy?94Ox-pbcM4Jc|6LG zm7i}6hS%4lb#5@~<vWo5PsTysyli4KcLko`tphnO1S9;1WA}MC+!%HqK1@-Ar!U9D znT%9$7n;raSlxsm!r*yU30Ku^K<Elam0d|Rw=5ZRLvqOJsx{2E#b3#h?}Bh)^LX%_ zj`TiPlv5(LhrD&@!l%o8u(qZJtpy`YhIxKtk39NBg^L6@lS4*A{3}CvWbKMytB>NQ zDtYFXrXD$++=<)I&ZUVFjwYj4+R%+gEAhMX2#g!APdD|IF{uy6!_f^PfW--PdH8#} zW6VpkZi^xWu4th~R!d;Q%>(db-8mS3rWHJDH{q(w_9kO*&NaE)RKqG5Xy6+Sie>j$ zrti`Mv{ou)_Lx~<a#a|<@7PVIg^ni=EXIJ|8wt*vfpC&;GeC}a2IJPU4AeY%8>a2t zg<n$xQBOmKtu492=8V{m9_lIo<UmyZm2>)MkM|D;@~7vUZgUEf)-A_o`%75A%nhuh z2FZ#hW%LNRjp?a`SHQM~rL*JVN`3%dRmfw_A5Y^IfGTh&tO<Y+>l3*0M;7rrd5Dxh z2mw2BNeEI`;99hs<GDadQg>kybqM}KI%Su@k!us6&1oaOvGyZ2&VEnzbsv%H*ibY~ zRt2xH9pL<4kPM7jz;#Sa$8)Yd(EeTmcMSZ*1jW;^v)<a|qH8z=)rOP#H6O@^<}#8P zF%x!G$#7o8?uSV8+gM-NMkEwtVEpGK7;tW)Zx!Pq$j1d#GA47Bcxh3=H-ea6kwP7R zeW#93>@lxy2DdA!h59J6BsOS}#+3Jik;X<7H_kbD^CXbCMy}!IV<U*=>;lMAabSXl zf01926=cfESo&l`9hoj5&Drl3LdRc!Oxk~)frj2_nEMq>+=3Oj>1M;AMM8?(u2Bh* zAvz@e!%?W}cEso_H!;aKkX%a_XPKNI^v%{hCPzt+<CVS+2G$(|&dBvJ%$5a>v!>L* zSD0I-IFG)sT>*^?_T!?Y*|cG-9v(I2f>SfV<~tfNUv(nn#wf5emo6cy#V0VkOM=^y zKaNPgK1XzKorL46x1gl22E30PMLC;dx}~d+l#fv3w0W?w!#ony!){XBaa*yaZ!zqD z9fnz&duj2g4mPyy4z!-UOmh?!xr-+W!nod<*m6pjF^RdzcBzHYYipK3$L#AUv2hqG zDouy1*-B{FS&MEOQz5!465cGG%l;hvj=RfegTB{8>>hubJe*;Vr#$0O%=9(gc*&SP z`Y;(yFUy0zeH*>qc>uclB8g(`Of<>ggj%l}@JCY|@d>+3B^$<aQznb!)NDt*S8xl? ze|d`2k6gh+vJv2vod>f^32wQsiYKd-VD_gm+(hFpTv8cH?{B>c!xUc7-qL%FkV-Vn zY82sE<;vpJ&fv-khVQ|p@hVma`GfX)KNNdjcXryPzaM&};H>BquJ_A_3d^64;o z_(UY*HFq=!Rqn^bHVo|Nu7dk%8l3VdE-ecFKzo}O!2QBi;Qd1kt(@=C@SqV8)w~A% z^%|(tp;G2?-x2tJb3D#_(19{*`p{!`8orXdfaa@}sgX`1a@%L&5p6SskXN)}Q5HCQ zYQbaG0F!iyk2LrGb83GtlS~McgVYT>@X0PYMoL78>vLy4F;JWjnhue)&;K1fNwmYF z3!BM6<!JW#j#`$geqdj3c*Z=C-2hn|3ZZeoY(>PCl^~Eg3opnmfR{4!dBu~VF!`V) zho3{@I}U#P*UaWWIl8jwclPY+kzBv!BT+v86u2cwf`01}-rDm8^AQ6>1zmquK>YXe z`25JZB^j)3oq5sFRhU>@O5|`mT!mBA^4mDv<kd8k3;1gSUV=BuZ4UoYPq%HS{ob9s zU3DE?wmNM4g9KmQwuxkpTn;;&e==?A@_3?WmPzKlQkXMY3c_Wcp~OxbI$Ny`qC2gr z{Jyv7vf?u=x~L27TO(O7ZA$$@Lcj$_pv?CeSW!Cy-i99_j|AmNgZ3;M+!zRoL7R{y z*I}=P5<E^+g{iOm$jz$zq(v!%dL=Fftz;v%*YO*jTbN3<Rts`2$+n<T+ix<q^D(+* zJ!f|fYoQa|hI2nnuO=Jfu2RPhBHT}%mSCP=3J#SNCfzC}o|z_$xL`J{U*8DT3ek`v z=7#HamXrBnuSif@JC0b{O^Wu*aGo_xg4)DN5Lj+Z4)?~hf(F;fu%?wbP__>j%AbV~ z7YabV>@eQA=LMq`F0tEw<dN#M4pJ3<1hjT!F*AIM!N&PC3AsHHF0EXL`=^ZHw1ssO zfgl4Ads>SpQ&!_Q_tW6e8bd0-e!+sb0YfVJYmlPy_bTCO$=2P~Rd<)i-aiw)wmWy& z?wVn(FM;Cci$Sloh)p}aiVU_Vz__aps6X`xzBX7)ld9Ih4C`#v^m~L`UaW>&HskSW zeGTo;>BJNDIw;@YUqN(S!D{zHY@hrPM!|HP@In{WYOaHD@OpA&l_c+ASD6Or<w0<6 z1=x&Rg|c-zta`>3#`wS?6qet{Zgb8z-ugKj+QUwQTEqo<Tz(^!mez*%_vCQmMR`~y zR?PZp??+DWUYu)K24AMB(24!^kbPJfpDM(Y)2ib*9_xcJq2n_S=4ZgQtTpIxCabdY zmmHp58ceJ#J#dccTsm|4Te4u3FJ7KI3$>dsQkT{e{IKB(J!#R)<VWA7%{?LXp;&t5 zy_P9>;Qb6x<JswgYn`;+d4SB?>`66MD~OjfFY!A^2zXH^D!acU2k)JN2_Lr7kFSov z&YT&j!DAIE5u=Fl!<Tscm?wqO3{dp=g5fvQ@$9$&e3IwO8bbwbmoI{8EBAs^+)tvU zB1fFo-6AiuZE%wNM)0zVzy;6NLBN9((Du=a%ygSgclF&R%9G}BPgI>GlQrJ4F19P! zRWD|Phe{H=^m8;mzcY#+w+><UZ|z|8?x_=r5rLp99EuGeW)Ww_bHEjC!DW4O$mVBq z$mRJYue9x0d|HQo)>@3bvN`<xBn%?HOoK`{X=r%rPhVG$hF#prsQNL0PFT1fvhLf0 z`{sjKBfJq!Qd^jJJ<Y)A_2C8I2r@&h5CiPha8p7#eX+EWHk}v4*;*fs{Y1p^o?Zn8 zpIZdRl^if@vVz+NH?USBiS&-t0oPytxXnV|L|>_kO}x_qt?Rx*n)oC1NtwtoS?Yv4 zBo9J_`+kUca|DyZ2T7(;GaK_x5^VH$kxbPjkaa%E=(bO#k1WFAVcZy)xZW77)wFTi zoj_zPWl*C_88<Fi3}c(T;Xsoz-Kbd#g4zWT&3F@^?W?ip+$~s;-vQ5_q+!kFc)WA& zJW41&huftesgQU>WwhZvT6Dt=*WKAepRZBk*iF9#YSz4ivubt7YgvoKwzT59qki!5 za2$+y!^1nKo0?W@)AVkDsqQtTDdZMKrVbtFyk~wYhCu9+B&-aa1dlL_l!|lUmZ2?p z*XUu;z&)&t6T_#c?BND?1U|^sLm{>ISY`PV2JF^A``CrNvPCb@N|%N8dIxCr#VNG( zeG^Sg)#PwiN&&~KlP0d~B2%{cVqPp49vExkq(uVo#qBt36}UkTc~>z;KW3QJ=x%5I z&&uLOwYPMxJqvryTiE~q;Njo!A3ON|V{NB5>!Bw|N?zvCrMK*H%U}g&39G@kFmtr2 zo6p%dEEM1OoI#I7AqeGOgZM=O%)uk0Vbiw`bac}}W>+vikNS-3Vm~p5L!`M^m@GUV zD2^I+Yf<X!Z6f5ELVnk6gSjajm?9Db58Qv03vWp#QJ;liuapdt+;oS&)6b@U7OPOJ zg;xR`%4Mr=XyMD_TDZRJByNa50Lz5Zn0+s8aBy%r#PY&@1F>S<S?UKfj*6q<#(KIn zSCiSqi2+d0Fv&fXgvX2}Aa>;dJ(Af2>DOY)8%9WTLkCJo{JR0TTFN7FHXpzz$PXL* zcEZ_*b8%v>1DVi$7NvenMqgzeydvrdT|bt>s)}%u(!{%S6PZd63j&T#-wfC8CBV5k zK_G285B=9p=hjVqjW7By!xFbzdU!gAcSzp=*Sgc7(7eW^Yi<G!Stv?3o!UV*+`9|s zmpvgPo*Tls?0D!YjmI=H4@&D3XtAg&w`-OdH-D8M!Y@%k>G#Z;+a@I7$~OpdJBuGG zURHi<NhS`5Oz^PeJCn(G?qZ!+F<2bY#eG6ToV&uqS<@5#pt#nWJ>2x8^0S{DLVqt) ze*Y}dzQGF~Bn0T5>56zsN0xK_dMP{g_6?Z3%A8tP1(5A=_Aql!2_$tN0SoQjD876< z%IYSwo&xRkVsSLOsK2H*y`SlvUsF-5AdMbqOaSTOOUcBU%ABGkUE<>(N7Y<{j5}8Q zljNr@r19M_&PT`f==NbF#8$2)72|I4ZqqD;!}ECi=kASYV>`NX=aXiPQAxu`>c>F+ zUKcg(495%oPNbo7I0j4i(ESS-$S;b+VK=v<dbJW^rtSyvL#N@;JZ<dm-wolcIT}^H zhn*9nF;OBLyrnf!?5P0f?v}CeaHJpZx;Pw<{8$8%>@l!>5XGMCFkoCJneq;-MM!k? zIh2{W63wsis{7-T!BZ^@i{7@;?Fm+7n$R8CyTYDul`g^qk;zcF{65Tk?8ff=rod?m z9H37O7NBtJN%Wg8$5~aq3SP+M(C;r&A#vmkkjef5_7iWS`ST!pvhE{2n>L8GSAxKq z7u4)u&w#JUSkA`gV!A1<k<Jq7Wk}^Am@q~Li;k({_f<h4{Je@hY!>3CKTW}-np_CI zRZGk!&c_$Ln=ffC>rF&9m_Y2aJajX+!e`u2BD?P{yUAfR_-r|aF;7c~a9JprbH0%k z9#SNA^L6&#m38>rBbJpQX)yOjFFW(LJ=Td1(xRMY#E3zf>peimJ{>-kiqqn84t`VN z^}nHm|HM6pPP6!qL!~Pj26xB8mWoKoIIaf!coD?sz5=GWLK2-7MsSUeabQ&IYDP2t zEHkamnO+Dq;$3M-Ad~wLYvf~SvHm48YUvO5yihIVF45*##w`Y)umv37QzN0Onu{6= zi_rOm6PU-OfWY)t<D&ax@$~m7x+JIw)p^(F_>ND8*ZGaljQQP(U+2NU8S*E4UEXU0 z?VI(<ybu*|in&V87aYRth&X+*0Xm-^1&^|3K&3(fuGKn--_#AUu)BojTjk@O;U#$2 zel<3oD`ZNQPT)4}PB!aE5L`KP9s?ZLP~)loU^eA39o9S=T|46-@Ma2C*b)xbY0rpd zZ~^_D)JOVcGRTr#KiINmBwno6WLo+UvXb*>!)QZ)Tt3YY&Ry*y0?`+6<Asy3^@kKk z!NdplWR`$lV?Lvo`w%CO{D?{?{m94Av6!psgzJ2L*xw0C@U@qhb9?ZHy3MZw$@E~% zxv&}1bO-3<G6^v941<j83=~ja3SfII^1~5nUgx*}4<NoGw}(Z*xFyz@^*Vq&nHxX` z?FL9fU@qwrUXPwNxkDKLd42kyG3Moh{|FGsYFlu&Af!e)7G4|9VO^Sw&}^+2%t>SE zm<a`>9o6x`{sstl+=7YMa;c0|Gm~~V6rUMp5xF(@;F;17s<rYq;eMUYHQGCo9kb{R zl@4Bo1G_?Cr+6mKs??`06RW79bsv6ei$cQ@`@mIPluL9L;DifzslfCXFz4C=C_gKO z%fFuo_ZSAJYx}~*PkHFSB^bEfbyW1cKfbZvOr-=ONY09%^lhCa$b%e~-&ck%r`Kd} zh&TIu@D=sfnLu2hkA#3P!klM^_wY!EH%!l7$ay9=3}#LmAnS}~!(DY>lwH(HSHuaT z*lqQp1tj$kg#Q9DA0Ix#^}I_uZXuGSR4)|V`)xq)q&mGDFck|{Mnk2N*%19#_{+B? zdGG&W#Y0b9SBJg2f994RB-LW+x-+!(kQUi$w+MC)?f_MrXj0Q42b$LYOsMZPbiW=* zcf=-}Jd}`P_geR0x8EK5c8()cE*(Yf@DLSj`bHK`2*E=8`?zxcB`R=gKCpsMAi~fb zWV$Obug8VPJNOg#CVlc;dIVg#=7g^W0-1=}jbxhb8VH-e9JgO7q#Go3xM~U$QHmpi z=<A2Z30h>t6bJTe!YJ;<d>eB9ybBuU4v=Em3^cpnXL5_Ta;MEmCX4NUGp{oIXiUa+ zoFd2KDQ`Knz93pTsbm7SNveRv*9a;r;f3j2BWUfS?PQ77c~b7$4H?W-{BbaW?U>d? ze0j&s{+a?Pz5FIh-4sH(-NEP~_?%ZXY=PrNkD<5+506*l;Ow^TSZESR-!yfi%J`pj z<D$Dnb^02Z7h-_A>Bm4%=qfI}l!}hI`mn715h$k*5^18&wjCbHy}4VE`^kS8`MIW) zS9BeYpY@M1zi)Lg%ndi#|4M?0^3I~SeG@XNn*EX8YJ42VA3QL$*yR7v<=^X(ZxQb% z&gO7cc?GJ2SINa)A|Pbe4cA{S0dMzO?8qpB-@*A*dG9kOk*1^0WfAb=YT*9^s#p(I literal 0 HcmV?d00001 diff --git a/models/generate-parakeet-test-model.py b/models/generate-parakeet-test-model.py new file mode 100755 index 00000000000..192a96ce627 --- /dev/null +++ b/models/generate-parakeet-test-model.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +import struct +import sys +import numpy as np +from pathlib import Path + +def write_tensor(fout, name, data): + n_dims = len(data.shape) + data = data.astype(np.float32) + ftype = 0 # GGML_TYPE_F32 + + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + data.tofile(fout) + +def generate(output_path): + rng = np.random.default_rng(42) + + hparams = { + 'n_vocab': 10, + 'n_audio_ctx': 3200, + 'n_audio_state': 8, + 'n_audio_head': 2, + 'n_audio_layer': 1, + 'n_mels': 16, + 'ftype': 0, + 'n_fft': 64, + 'subsampling_factor': 8, + 'n_subsampling_channels': 4, + 'n_conv_kernel': 3, + 'n_pred_dim': 8, + 'n_pred_layers': 1, + 'n_tdt_durations': 2, + 'n_max_tokens': 5, + } + + n_vocab = hparams['n_vocab'] + n_state = hparams['n_audio_state'] + n_head = hparams['n_audio_head'] + n_layer = hparams['n_audio_layer'] + n_mels = hparams['n_mels'] + n_fft = hparams['n_fft'] + n_sub_fac = hparams['subsampling_factor'] + n_sub_ch = hparams['n_subsampling_channels'] + n_conv_ker = hparams['n_conv_kernel'] + dec_dim = hparams['n_pred_dim'] + n_pred_l = hparams['n_pred_layers'] + n_tdt = hparams['n_tdt_durations'] + + n_pre_enc = (n_mels // n_sub_fac) * n_sub_ch + n_head_dim = n_state // n_head + n_pred_embed = n_vocab + 1 + n_lstm_gates = 4 * dec_dim + n_joint_out = n_vocab + n_tdt + 1 + n_freqs = n_fft // 2 + 1 + + def f32(*shape): + return rng.standard_normal(shape).astype(np.float32) + + with open(output_path, 'wb') as fout: + fout.write(struct.pack("I", 0x67676d6c)) + + for key in ['n_vocab', + 'n_audio_ctx', + 'n_audio_state', + 'n_audio_head', + 'n_audio_layer', + 'n_mels', + 'ftype', + 'n_fft', + 'subsampling_factor', + 'n_subsampling_channels', + 'n_conv_kernel', + 'n_pred_dim', + 'n_pred_layers', + 'n_tdt_durations', + 'n_max_tokens']: + fout.write(struct.pack("i", hparams[key])) + + fout.write(struct.pack("i", n_mels)) + fout.write(struct.pack("i", n_freqs)) + f32(n_mels, n_freqs).tofile(fout) + + fout.write(struct.pack("i", n_fft)) + f32(n_fft).tofile(fout) + + for d in range(n_tdt): + fout.write(struct.pack("I", d)) + + tokens = ['<unk>', '<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab - 3)] + assert len(tokens) == n_vocab + fout.write(struct.pack("i", n_vocab)) + for tok in tokens: + tok_bytes = tok.encode('utf-8') + fout.write(struct.pack("i", len(tok_bytes))) + fout.write(tok_bytes) + + write_tensor(fout, "encoder.pre_encode.out.weight", f32(n_state, n_pre_enc)) + write_tensor(fout, "encoder.pre_encode.out.bias", f32(n_state)) + + write_tensor(fout, "encoder.pre_encode.conv.0.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.0.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.2.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.2.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.3.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.3.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.5.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.5.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.6.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.6.bias", f32(1, n_sub_ch, 1, 1)) + + for i in range(n_layer): + p = f"encoder.layers.{i}" + + write_tensor(fout, f"{p}.norm_feed_forward1.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward1.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_conv.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_conv.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.pointwise_conv1.weight", f32(2*n_state, n_state)) + write_tensor(fout, f"{p}.conv.depthwise_conv.weight", f32(n_state, n_conv_ker)) + write_tensor(fout, f"{p}.conv.batch_norm.weight", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_mean", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_var", np.abs(f32(n_state))) + num_batches = np.zeros(1, dtype=np.int32) + write_tensor(fout, f"{p}.conv.batch_norm.num_batches_tracked", num_batches) + write_tensor(fout, f"{p}.conv.pointwise_conv2.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_self_att.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_self_att.bias", f32(n_state)) + + write_tensor(fout, f"{p}.self_attn.pos_bias_u", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.pos_bias_v", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.linear_q.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_k.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_v.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_out.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_pos.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_feed_forward2.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward2.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_out.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_out.bias", f32(n_state)) + + write_tensor(fout, "decoder.prediction.embed.weight", f32(n_pred_embed, dec_dim)) + + def reorder_gates(data): + h = data.shape[0] // 4 + return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + for i in range(n_pred_l): + base = f"decoder.prediction.dec_rnn.lstm" + write_tensor(fout, f"{base}.weight_ih_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.weight_hh_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.bias_h_l{i}", reorder_gates(f32(n_lstm_gates) + f32(n_lstm_gates))) + + write_tensor(fout, "joint.pred.weight", f32(dec_dim, dec_dim)) + write_tensor(fout, "joint.pred.bias", f32(dec_dim)) + write_tensor(fout, "joint.enc.weight", f32(dec_dim, n_state)) + write_tensor(fout, "joint.enc.bias", f32(dec_dim)) + write_tensor(fout, "joint.joint_net.2.weight", f32(n_joint_out, dec_dim)) + write_tensor(fout, "joint.joint_net.2.bias", f32(n_joint_out)) + + size = Path(output_path).stat().st_size + print(f"Generated {output_path} ({size / 1024:.1f} KB)") + +if __name__ == '__main__': + output = sys.argv[1] if len(sys.argv) > 1 else 'models/for-tests-ggml-parakeet-tdt.bin' + generate(output) diff --git a/models/requirements-parakeet.txt b/models/requirements-parakeet.txt new file mode 100644 index 00000000000..5239ae0af5d --- /dev/null +++ b/models/requirements-parakeet.txt @@ -0,0 +1,3 @@ +torch +numpy +pyyaml diff --git a/scripts/quantize-parakeet.sh b/scripts/quantize-parakeet.sh new file mode 100755 index 00000000000..7816696bfcb --- /dev/null +++ b/scripts/quantize-parakeet.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -e + +build_dir=build +modelname=ggml-parakeet-tdt-0.6b-v3 +model=models/${modelname}-f32.bin +cmd=parakeet-quantize + +cmake --build ${build_dir} --target $cmd -j 12 + +${build_dir}/bin/${cmd} $model models/${modelname}-q8_0.bin q8_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_0.bin q4_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_k.bin q4_k +${build_dir}/bin/${cmd} $model models/${modelname}-q2_k.bin q2_k diff --git a/scripts/upload-parakeet.py b/scripts/upload-parakeet.py new file mode 100644 index 00000000000..3644bec8bd3 --- /dev/null +++ b/scripts/upload-parakeet.py @@ -0,0 +1,157 @@ +import argparse +import os +from huggingface_hub import HfApi, create_repo + +USER_NAME = "ggml-org" +REPO_ID = f"{USER_NAME}/parakeet-GGUF" + +MODELS = { + "f32": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f32.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f32.bin", + "description": "Full precision (F32)", + }, + "f16": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f16.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f16.bin", + "description": "Half precision (F16)", + }, + "q8_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "description": "8-bit quantized (Q8_0)", + }, + "q4_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "description": "4-bit quantized (Q4_0)", + }, + "q4_k": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "description": "4-bit K-quantized (Q4_k)", + }, +} + +def build_model_card(uploaded_variants): + lines = [ + f"---", + f"license: mit", + f"base_model: nvidia/parakeet-tdt-0.6b-v3", + f"tags:", + f"- gguf", + f"- asr", + f"---", + f"", + f"# Parakeet TDT 0.6B v3 (GGUF)", + f"", + f"GGUF conversions of [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) for use with [whisper.cpp](https://github.com/ggml-org/whisper.cpp).", + f"", + f"## Available files", + f"", + ] + + for key, m in MODELS.items(): + if key in uploaded_variants: + lines.append(f"- `{m['remote_name']}` — {m['description']}") + + lines += [ + f"", + f"## Usage", + f"", + f"Build parakeet-cli:", + f"```console", + f"git clone https://github.com/ggml-org/whisper.cpp.git", + f"cd whisper.cpp", + f"cmake -B build -S .", + f"cmake --build build --target parakeet-cli -j $(nproc)", + f"```", + f"", + f"Download a model (e.g. Q8_0):", + f"```console", + f"hf download {REPO_ID} {MODELS['q8_0']['remote_name']} --local-dir models", + f"```", + f"", + f"Run:", + f"```console", + f"./build/bin/parakeet-cli -m models/{MODELS['q8_0']['remote_name']} -f samples/jfk.wav", + f"```", + f"", + ] + + return "\n".join(lines) + + +def upload_variant(api, key): + m = MODELS[key] + local_path = m["local_path"] + + if not os.path.exists(local_path): + print(f" Skipping {key}: {local_path} not found") + return False + + print(f" Uploading {m['remote_name']} ({m['description']})...") + api.upload_file( + path_or_fileobj=local_path, + path_in_repo=m["remote_name"], + repo_id=REPO_ID, + repo_type="model", + commit_message=f"Upload {m['remote_name']}", + ) + return True + + +def main(): + parser = argparse.ArgumentParser(description="Upload parakeet GGUF models to Hugging Face") + parser.add_argument( + "variants", + nargs="*", + default=None, + metavar="{" + ",".join(MODELS.keys()) + "}", + help="Model variants to upload (default: all)", + ) + parser.add_argument( + "--no-model-card", + action="store_true", + help="Skip updating the model card README", + ) + args = parser.parse_args() + + api = HfApi() + create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True) + + variants = args.variants if args.variants else list(MODELS.keys()) + + unknown = [v for v in variants if v not in MODELS] + if unknown: + parser.error(f"unknown variant(s): {', '.join(unknown)} (choose from {', '.join(MODELS.keys())})") + + uploaded = [] + for key in variants: + if upload_variant(api, key): + uploaded.append(key) + + if not uploaded: + print("No models were uploaded.") + return + + if not args.no_model_card: + print("Updating model card...") + existing = [k for k in MODELS if k in uploaded or + any(f.rfilename == MODELS[k]["remote_name"] + for f in api.list_repo_files(REPO_ID, repo_type="model") + if hasattr(f, "rfilename"))] + card = build_model_card(existing if existing else uploaded) + api.upload_file( + path_or_fileobj=card.encode(), + path_in_repo="README.md", + repo_id=REPO_ID, + repo_type="model", + commit_message="Update README.md", + ) + + print(f"\nDone. Repository: https://huggingface.co/{REPO_ID}") + + +if __name__ == "__main__": + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 095a2791de5..4e7c5b24dc3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -109,23 +109,43 @@ add_library(whisper whisper.cpp ) +add_library(parakeet + ../include/parakeet.h + parakeet-arch.h + parakeet.cpp + ) + +target_include_directories(parakeet PUBLIC . ../include) +target_compile_features (parakeet PUBLIC cxx_std_11) +target_link_libraries(parakeet PUBLIC ggml Threads::Threads) + # Set the version numbers set_target_properties(whisper PROPERTIES VERSION ${PROJECT_VERSION} SOVERSION ${SOVERSION} ) +set_target_properties(parakeet PROPERTIES + VERSION ${PROJECT_VERSION} + SOVERSION ${SOVERSION} +) + target_include_directories(whisper PUBLIC . ../include) target_compile_features (whisper PUBLIC cxx_std_11) # don't bump if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN") set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN) + set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN) endif() if (WHISPER_EXTRA_FLAGS) target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS}) endif() +if (PARAKEET_EXTRA_FLAGS) + target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS}) +endif() + find_package(Threads REQUIRED) target_link_libraries(whisper PUBLIC ggml Threads::Threads) @@ -144,4 +164,7 @@ endif() if (BUILD_SHARED_LIBS) set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD) + + set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD) endif() diff --git a/src/parakeet-arch.h b/src/parakeet-arch.h new file mode 100644 index 00000000000..3407a95c9c7 --- /dev/null +++ b/src/parakeet-arch.h @@ -0,0 +1,188 @@ +#pragma once + +#include "ggml.h" + +#include <map> + +enum parakeet_tensor { + // Encoder pre_encode + PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, + + // Encoder layers (per-layer) + PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, + PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, + PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_BIAS, + PARAKEET_TENSOR_ENC_CONV_BN_MEAN, + PARAKEET_TENSOR_ENC_CONV_BN_VAR, + PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, + PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, + PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, + PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, + + // Prediction network + PARAKEET_TENSOR_PRED_EMBED_WEIGHT, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_H, + + // Joint network + PARAKEET_TENSOR_JOINT_PRED_WEIGHT, + PARAKEET_TENSOR_JOINT_PRED_BIAS, + PARAKEET_TENSOR_JOINT_ENC_WEIGHT, + PARAKEET_TENSOR_JOINT_ENC_BIAS, + PARAKEET_TENSOR_JOINT_NET_WEIGHT, + PARAKEET_TENSOR_JOINT_NET_BIAS, +}; + +static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"}, + + // Encoder layers (use %d for layer number) + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"}, +}; + +static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD}, + + // Encoder layers + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD}, +}; diff --git a/src/parakeet.cpp b/src/parakeet.cpp new file mode 100644 index 00000000000..b5da73e985c --- /dev/null +++ b/src/parakeet.cpp @@ -0,0 +1,3838 @@ +#include "parakeet.h" +#include "parakeet-arch.h" + +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include <atomic> +#include <algorithm> +#include <cassert> +#include <cfloat> +#define _USE_MATH_DEFINES +#include <cmath> +#include <climits> +#include <cstdarg> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <functional> +#include <cctype> +#include <map> +#include <random> +#include <set> +#include <string> +#include <thread> +#include <vector> + +#ifdef _MSC_VER +#include <codecvt> +#endif + +#if defined(PARAKEET_BIG_ENDIAN) +template<typename T> +static T byteswap(T value) { + T value_swapped; + char * source = reinterpret_cast<char *>(&value); + char * target = reinterpret_cast<char *>(&value_swapped); + int size = sizeof(T); + for (int i = 0; i < size; i++) { + target[size - 1 - i] = source[i]; + } + return value_swapped; +} + +template<typename T> +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast<T *>(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data<int16_t>(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data<ggml_fp16_t>(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data<int32_t>(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data<float>(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +PARAKEET_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal (ggml_log_level level, const char * format, ...); +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define PARAKEET_LOG_INFO(...) parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define PARAKEET_DEBUG + +#if defined(PARAKEET_DEBUG) +#define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define PARAKEET_LOG_DEBUG(...) +#endif + +#define PARAKEET_ASSERT(x) \ + do { \ + if (!(x)) { \ + PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#define PARAKEET_MAX_NODES 8192 + +// Threshold for when local attention should be used. +// 8192 frames x 80ms = 655 s (about 10.9 mins) +static constexpr int PARAKEET_LOCAL_ATTN_THRESHOLD = 8192; +// Window of context in each director of the current token. +// 128 frames * 80ms = 10.24 s +static constexpr int PARAKEET_LOCAL_ATTN_WINDOW = 128; + +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector<char> buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph * graph, + int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data); + } + + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend.get(), n_threads); + } + + return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS; +} + +static bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph * graph, + int n_threads, + bool sched_reset = true) { + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + + auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (fn_set_n_threads) { + fn_set_n_threads(backend, n_threads); + } + } + + const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); + + if (!t || sched_reset) { + ggml_backend_sched_reset(sched); + } + + return t; +} + +// TODO: move these functions to ggml-base with support for ggml-backend? + + +struct parakeet_mel { + int n_len = 0; + int n_len_org = 0; + int n_mel = 0; + + std::vector<float> data; +}; + +struct parakeet_filters { + int32_t n_mel = 0; + int32_t n_fb = 0; // number of frequency bins + + std::vector<float> data; +}; + +struct parakeet_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 8192; + size_t max_token_length = 0; + + std::map<token, id> token_to_id; + std::map<id, token> id_to_token; + + id token_unk; + id token_bos; + id token_blank; + id token_eos; +}; + +struct parakeet_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector<parakeet_token_data> tokens; +}; + +struct parakeet_batch { + int32_t n_tokens; + + parakeet_token * token; + int32_t * i_time; // index of the audio frame + parakeet_pos * pos; + int32_t * n_seq_id; // always 1, here for consistency with llama.cpp + parakeet_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +// ggml_backend_sched wrapper for parakeet usage +struct parakeet_sched { + ggml_backend_sched_t sched = nullptr; + + std::vector<uint8_t> meta; +}; + +// TODO: Find out is there a multiple version types. It is not yet clear to me +// at this point. +enum parakeet_arch { + PARAKEET_ARCH_UNKNOWN = 0, + PARAKEET_ARCH_TDT = 1, // NVIDIA Parakeet TDT (RNN-T) +}; + +struct parakeet_hparams { + int32_t n_vocab = 8192; + int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input + int32_t n_audio_state = 1024; + int32_t n_audio_head = 8; + int32_t n_audio_layer = 24; + int32_t n_mels = 128; + int32_t ftype = 1; + int32_t n_fft = 512; // FFT size for mel spectrogram + float eps = 1e-5f; + int32_t subsampling_factor = 8; + int32_t n_subsampling_channels = 256; + int32_t n_conv_kernel = 9; + int32_t n_pred_dim = 640; + int32_t n_pred_layers = 2; + int32_t n_tdt_durations = 5; + int32_t n_max_tokens = 10; + + parakeet_arch arch = PARAKEET_ARCH_TDT; +}; + +struct parakeet_layer_encoder { + struct ggml_tensor * norm_ff1_w = nullptr; + struct ggml_tensor * norm_ff1_b = nullptr; + + struct ggml_tensor * ff1_linear1_w = nullptr; + struct ggml_tensor * ff1_linear2_w = nullptr; + + struct ggml_tensor * norm_conv_w = nullptr; + struct ggml_tensor * norm_conv_b = nullptr; + + struct ggml_tensor * conv_pw1_w = nullptr; // pointwise_conv1 + struct ggml_tensor * conv_dw_w = nullptr; // depthwise_conv + struct ggml_tensor * conv_bn_w = nullptr; // batch_norm weight + struct ggml_tensor * conv_bn_b = nullptr; // batch_norm bias + struct ggml_tensor * conv_bn_mean = nullptr; // batch_norm running_mean + struct ggml_tensor * conv_bn_var = nullptr; // batch_norm running_var + struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked + struct ggml_tensor * conv_pw2_w = nullptr; // pointwise_conv2 + + struct ggml_tensor * norm_attn_w = nullptr; + struct ggml_tensor * norm_attn_b = nullptr; + + struct ggml_tensor * attn_pos_bias_u = nullptr; + struct ggml_tensor * attn_pos_bias_v = nullptr; + struct ggml_tensor * attn_q_w = nullptr; + struct ggml_tensor * attn_k_w = nullptr; + struct ggml_tensor * attn_v_w = nullptr; + struct ggml_tensor * attn_out_w = nullptr; + struct ggml_tensor * attn_pos_w = nullptr; + + struct ggml_tensor * norm_ff2_w = nullptr; + struct ggml_tensor * norm_ff2_b = nullptr; + + struct ggml_tensor * ff2_linear1_w = nullptr; + struct ggml_tensor * ff2_linear2_w = nullptr; + + struct ggml_tensor * norm_out_w = nullptr; + struct ggml_tensor * norm_out_b = nullptr; +}; + +struct parakeet_lsmt_layer { + struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight + struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight + struct ggml_tensor * b_h = nullptr; // bias (ih folded into hh at conversion time) +}; + +struct parakeet_prediction_network { + struct ggml_tensor * embed_w = nullptr; + + std::vector<parakeet_lsmt_layer> lstm_layer; +}; + +struct parakeet_joint_network { + struct ggml_tensor * pred_w = nullptr; + struct ggml_tensor * pred_b = nullptr; + struct ggml_tensor * enc_w = nullptr; + struct ggml_tensor * enc_b = nullptr; + struct ggml_tensor * net_w = nullptr; + struct ggml_tensor * net_b = nullptr; +}; + +struct parakeet_model { + parakeet_filters filters; + parakeet_hparams hparams; + + struct ggml_tensor * enc_pre_out_w = nullptr; + struct ggml_tensor * enc_pre_out_b = nullptr; + struct ggml_tensor * enc_pre_conv_0_w = nullptr; + struct ggml_tensor * enc_pre_conv_0_b = nullptr; + struct ggml_tensor * enc_pre_conv_2_w = nullptr; + struct ggml_tensor * enc_pre_conv_2_b = nullptr; + struct ggml_tensor * enc_pre_conv_3_w = nullptr; + struct ggml_tensor * enc_pre_conv_3_b = nullptr; + struct ggml_tensor * enc_pre_conv_5_w = nullptr; + struct ggml_tensor * enc_pre_conv_5_b = nullptr; + struct ggml_tensor * enc_pre_conv_6_w = nullptr; + struct ggml_tensor * enc_pre_conv_6_b = nullptr; + + std::vector<parakeet_layer_encoder> layers; + + parakeet_prediction_network prediction; + + parakeet_joint_network joint; + + std::vector<uint32_t> tdt_durations; + + std::vector<ggml_context *> ctxs; + + std::vector<ggml_backend_buffer_t> buffers; + + int n_loaded = 0; + std::map<std::string, struct ggml_tensor *> tensors; +}; + +struct parakeet_lstm_state_layer { + struct ggml_tensor * h_state = nullptr; + struct ggml_tensor * c_state = nullptr; +}; + +struct parakeet_lstm_state { + std::vector<parakeet_lstm_state_layer> layer; + + std::vector<uint8_t> ctx_buf; + + ggml_backend_buffer_t buffer = nullptr; +}; + +struct parakeet_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_predict_us = 0; + int64_t t_predict_build_us = 0; // time spent building the prediction graph + int64_t t_predict_alloc_us = 0; // time spent in ggml_backend_sched_alloc_graph + int64_t t_predict_compute_us = 0; // time spent in ggml_graph_compute_helper + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_predict = 0; // number of prediction network calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + parakeet_mel mel; + + parakeet_batch batch; + + int n_frames = 0; + + std::vector<ggml_backend_t> backends; + + parakeet_sched sched_encode; + parakeet_sched sched_decode; + + // outputs from encoder stages + struct ggml_tensor * enc_out = nullptr; + struct ggml_tensor * pred_out = nullptr; + + std::vector<uint8_t> enc_out_buf; + ggml_backend_buffer_t enc_out_buffer = nullptr; + + std::vector<uint8_t> pred_out_buf; + ggml_backend_buffer_t pred_out_buffer = nullptr; + + struct ggml_tensor * attn_mask = nullptr; + + std::vector<float> inp_mel; + std::vector<float> inp_mask; + + std::vector<float> logits; + + std::vector<parakeet_segment> result_all; + + std::vector<parakeet_token> decoded_tokens; + std::vector<parakeet_token_data> decoded_token_data; + + std::string path_model; + + int32_t n_audio_ctx = 0; + int32_t sched_encode_n_audio_ctx = 0; + + parakeet_lstm_state lstm_state; +}; + +// FFT cache for mel spectrogram computation +struct parakeet_mel_cache { + int n_fft = 0; + + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + std::vector<float> sin_vals; + std::vector<float> cos_vals; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + std::vector<float> hann_window; + + // Window function from model (Parakeet uses actual window from training) + std::vector<float> window; + + void init(int fft_size) { + n_fft = fft_size; + sin_vals.resize(n_fft); + cos_vals.resize(n_fft); + hann_window.resize(n_fft); + + fill_sin_cos_table(); + fill_hann_window(n_fft, true, hann_window.data()); + } + + void fill_sin_cos_table() { + for (int i = 0; i < n_fft; i++) { + double theta = (2 * M_PI * i) / n_fft; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } +}; + +struct parakeet_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; + ggml_type itype = ggml_type::GGML_TYPE_F16; + + parakeet_context_params params; + + parakeet_model model; + parakeet_vocab vocab; + + parakeet_state * state = nullptr; + + parakeet_mel_cache mel_cache; + + std::string path_model; +}; + +struct parakeet_global { + // We save the log callback globally + ggml_log_callback log_callback = parakeet_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static parakeet_global g_state; + +static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; + +static inline int utf8_codepoint_len(unsigned char c) { + if ((c & 0x80) == 0x00) return 1; + if ((c & 0xE0) == 0xC0) return 2; + if ((c & 0xF0) == 0xE0) return 3; + if ((c & 0xF8) == 0xF0) return 4; + return 1; +} + +static bool is_sentencepiece_control(const std::string & piece) { + return piece == "<unk>" || piece == "<s>" || piece == "</s>" || piece == "[BLANK]"; +} + +static std::string sentencepiece_normalize(const std::string & text) { + std::string normalized; + normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); + normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix + + for (unsigned char c : text) { + if (std::isspace(c)) { + normalized += PARAKEET_SPM_SPACE; + } else { + normalized += static_cast<char>(c); + } + } + + return normalized; +} + +static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { + if (is_sentencepiece_control(piece)) { + return ""; + } + + std::string text; + text.reserve(piece.size()); + + size_t pos = 0; + while (pos < piece.size()) { + if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { + if (!is_first_piece || !text.empty()) { + text += ' '; + } + pos += PARAKEET_SPM_SPACE.size(); + continue; + } + + text += piece[pos]; + ++pos; + } + + return text; +} + + +static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { + parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); + batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void parakeet_batch_free(struct parakeet_batch batch) { + if (batch.token) free(batch.token); + if (batch.i_time) free(batch.i_time); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + + +static size_t parakeet_sched_size(struct parakeet_sched & allocr) { + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) { + auto & sched = allocr.sched; + auto & meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + + if (!sched) { + PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); + return false; + } + + meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); + + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + ggml_backend_sched_free(sched); + sched = nullptr; + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + +static void parakeet_sched_free(struct parakeet_sched & sched) { + if (sched.sched) { + ggml_backend_sched_free(sched.sched); + sched.sched = nullptr; + } + + sched.meta.clear(); +} + + +template<typename T> +static void read_safe(parakeet_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool parakeet_lstm_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_layer, + int n_pred_dim) { + parakeet_lstm_state & lstm_state = pstate.lstm_state; + + lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2); + lstm_state.layer.resize(n_layer); + + struct ggml_init_params params = { + /*.mem_size =*/ lstm_state.ctx_buf.size(), + /*.mem_buffer =*/ lstm_state.ctx_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__); + return false; + } + + + for (int il = 0; il < n_layer; ++il) { + lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + } + + lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!lstm_state.buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__); + return false; + } + + ggml_backend_buffer_clear(lstm_state.buffer, 0); + + ggml_free(ctx); + + return true; +} + +static bool parakeet_pred_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_pred_dim) { + pstate.pred_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.pred_out_buf.size(), + /*.mem_buffer =*/ pstate.pred_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__); + return false; + } + + pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.pred_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static bool parakeet_enc_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_audio_state, + int n_frames_max) { + pstate.enc_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.enc_out_buf.size(), + /*.mem_buffer =*/ pstate.enc_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor context\n", __func__); + return false; + } + + pstate.enc_out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_frames_max); + pstate.enc_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.enc_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static ggml_backend_t parakeet_backend_init_gpu(const parakeet_context_params & params) { + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); + + ggml_backend_dev_t dev = nullptr; + + int cnt = 0; + if (params.use_gpu) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); + enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur); + const char * dev_name = ggml_backend_dev_name(dev_cur); + PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) { + PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, cnt); + if (cnt == params.gpu_device) { + dev = dev_cur; + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + if (dev == nullptr) { + PARAKEET_LOG_INFO("%s: no GPU found\n", __func__); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); + if (!result) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + } + + return result; +} + +static std::vector<ggml_backend_t> parakeet_backend_init(const parakeet_context_params & params) { + std::vector<ggml_backend_t> result; + + ggml_backend_t backend_gpu = parakeet_backend_init_gpu(params); + + if (backend_gpu) { + result.push_back(backend_gpu); + } + + // ACCEL backends + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + continue; + } + result.push_back(backend); + } + } + + ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw std::runtime_error("failed to initialize CPU backend"); + } + result.push_back(backend_cpu); + + return result; +} + +using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>; + +static buft_list_t make_buft_list(parakeet_context_params & params) { + // Prio order: GPU -> CPU Extra -> CPU + buft_list_t buft_list; + + // GPU + if (params.use_gpu) { + int cnt = 0; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) { + if (cnt == params.gpu_device) { + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + } + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + // CPU Extra + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // CPU + buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type()); + + return buft_list; +} + +static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS + case GGML_OP_GET_ROWS: + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + if (op == GGML_OP_MUL_MAT) { + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } else if (op == GGML_OP_GET_ROWS) { + int64_t num_indices = 8; + ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); + op_tensor = ggml_get_rows(ctx, w, indices); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; +} + + +// load the model from a ggml file +// + +// see the convert-parakeet-to-ggml.py script for details +// +static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) { + PARAKEET_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != GGML_FILE_MAGIC) { + PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + parakeet_hparams hparams; + { + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + read_safe(loader, hparams.n_fft); + read_safe(loader, hparams.subsampling_factor); + read_safe(loader, hparams.n_subsampling_channels); + read_safe(loader, hparams.n_conv_kernel); + read_safe(loader, hparams.n_pred_dim); + read_safe(loader, hparams.n_pred_layers); + read_safe(loader, hparams.n_tdt_durations); + read_safe(loader, hparams.n_max_tokens); + + hparams.arch = PARAKEET_ARCH_TDT; + wctx.model.hparams = hparams; + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype); + if (wctx.wtype == GGML_TYPE_COUNT) { + PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype); + return false; + } + + const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? "Parakeet TDT" : "unknown"; + PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name); + PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft); + PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps); + PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype); + PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); + PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels); + PARAKEET_LOG_INFO("%s: n_conv_kernel = %d\n", __func__, hparams.n_conv_kernel); + PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); + PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); + PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations); + PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fb); + + filters.data.resize(filters.n_mel * filters.n_fb); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load window function + { + int32_t n_window = 0; + read_safe(loader, n_window); + + wctx.mel_cache.window.resize(n_window); + loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float)); + +#ifdef GGML_BIG_ENDIAN + for (auto & datum : wctx.mel_cache.window) { + datum = byteswap(datum); + } +#endif + + PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window); + } + + // load TDT (Token and Duration Transducer) values + { + auto & tdt_durations = wctx.model.tdt_durations; + tdt_durations.resize(hparams.n_tdt_durations); + loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + + PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__); + for (const auto value : tdt_durations) { + PARAKEET_LOG_INFO("%u ", value); + } + PARAKEET_LOG_INFO("]\n"); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + std::string word; + std::vector<char> tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + vocab.max_token_length = std::max(vocab.max_token_length, word.size()); + } + // Blank token for transducer is at index n_vocab (8192), outside the vocabulary + int blank_id = n_vocab; + vocab.token_blank = blank_id; + vocab.id_to_token[blank_id] = "[BLANK]"; + vocab.token_to_id["[BLANK]"] = blank_id; + + // Set special token IDs by looking them up in the loaded vocabulary + // These are from the SentencePiece vocab file loaded above + if (vocab.token_to_id.find("<unk>") != vocab.token_to_id.end()) { + vocab.token_unk = vocab.token_to_id.at("<unk>"); + } else { + vocab.token_unk = 0; // Fallback + } + + if (vocab.token_to_id.find("<s>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<s>"); + } else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>"); + } else { + vocab.token_bos = 0; // Fallback + } + + if (vocab.token_to_id.find("</s>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("</s>"); + } else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("<|endoftext|>"); + } else { + vocab.token_eos = 0; // Fallback + } + + vocab.n_vocab = model.hparams.n_vocab; + + PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n", + __func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos); + } + + const ggml_type wtype = wctx.wtype; + + + const int n_audio_layer = hparams.n_audio_layer; + + // Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6) + size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6; + + std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error("failed to create ggml context"); + } + + ctx_map[buft] = ctx; + wctx.model.ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // Create a list of available bufts, in priority order + buft_list_t buft_list = make_buft_list(wctx.params); + + auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * { + ggml_op op = PARAKEET_TENSOR_INFO.at(type); + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", + PARAKEET_TENSOR_NAMES.at(type))); + } + + ggml_context * ctx = get_ctx(buft); + ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + + std::string tensor_name; + if (layer >= 0) { + tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer); + } else { + tensor_name = PARAKEET_TENSOR_NAMES.at(type); + } + + wctx.model.tensors[tensor_name] = tensor; + + return tensor; + }; + + // prepare tensors for the weights + + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + + const int n_audio_state = hparams.n_audio_state; + + model.layers.resize(n_audio_layer); + + // Encoder pre_encode + const int n_subsampling_channels = hparams.n_subsampling_channels; + const int n_pre_enc_features = (hparams.n_mels / hparams.subsampling_factor) * n_subsampling_channels; + model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_pre_enc_features, n_audio_state)); + ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w"); + model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); + ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b"); + + model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w"); + model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b"); + + model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w"); + model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b"); + + model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w"); + model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b"); + + model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w"); + model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b"); + + model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w"); + model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b"); + + // Encoder layers + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers[i]; + + // Feed forward 1 + layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i); + layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i); + + // Convolution module + layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i); + layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i); + layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i); + ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i); + layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_conv_kernel, n_audio_state), i); + ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i); + layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i); + layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i); + layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i); + layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i); + layer.conv_pw2_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i); + + // Self attention + layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i); + + // Feed forward 2 + layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + + // Output norm + layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + } + + // Prediction network (decoder) + const int dec_hidden = hparams.n_pred_dim; + const int n_pred_embed = hparams.n_vocab + 1; // vocab + blank token + const int n_lstm_gates = 4 * dec_hidden; // 4 LSTM gates + const int n_joint_out = hparams.n_vocab + hparams.n_tdt_durations + 1; // vocab + durations + blank + + // The prediction/joint hidden dimension is 640, which is not a multiple of the + // K-quant block size (256). For K-quant models, we keep these tensors at F32. + const int blck = ggml_blck_size(wtype); + const ggml_type pred_wtype = (blck > 1 && dec_hidden % blck != 0) ? GGML_TYPE_F32 : wtype; + const ggml_type join_wtype = pred_wtype; + + model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_pred_embed)); + model.prediction.lstm_layer.resize(hparams.n_pred_layers); + for (int i = 0; i < hparams.n_pred_layers; ++i) { + auto & layer = model.prediction.lstm_layer[i]; + layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.ih_w, "pred_%d_ih_w", i); + + layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.hh_w, "pred_%d_hh_w", i); + + layer.b_h = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_H, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_lstm_gates), i); + ggml_format_name(layer.b_h, "pred_%d_b_h", i); + } + + // Joint network + model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, dec_hidden)); + ggml_set_name(model.joint.pred_w, "pred_w"); + model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.pred_b, "pred_b"); + model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden)); + ggml_set_name(model.joint.enc_w, "enc_w"); + model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.enc_b, "enc_b"); + model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, n_joint_out)); + ggml_set_name(model.joint.net_w, "net_w"); + model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_joint_out)); + ggml_set_name(model.joint.net_b, "net_b"); + + ggml_free(ctx); + + // allocate tensors in the backend buffers + for (auto & p : ctx_map) { + ggml_backend_buffer_type_t buft = p.first; + ggml_context * ctx = p.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf) { + wctx.model.buffers.emplace_back(buf); + + size_t size_main = ggml_backend_buffer_get_size(buf); + PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6); + } + } + + // load weights + { + size_t total_size = 0; + + auto & tensors_map = wctx.model.tensors; + int & n_loaded = wctx.model.n_loaded; + + n_loaded = 0; + + std::vector<char> read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector<char> tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (tensors_map.find(name) == tensors_map.end()) { + PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = tensors_map[name.data()]; + + if (ggml_nelements(tensor) != nelements) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + if (ggml_backend_buffer_is_host(tensor->buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + total_size += ggml_nbytes(tensor); + n_loaded++; + } + + PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + + if (n_loaded == 0) { + PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (n_loaded != (int) tensors_map.size()) { + PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded); + return false; + } + } + + auto & buffers = wctx.model.buffers; + for (auto & buf : buffers) { + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +// conv subsampling + conformer encoder +static struct ggml_cgraph * parakeet_build_graph_encode(parakeet_context & pctx, parakeet_state & pstate) { + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_mel_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx; + const int n_mels = hparams.n_mels; + const int n_layer = hparams.n_audio_layer; + const int n_state = hparams.n_audio_state; + const float fc_factor = 0.5f; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_encode.meta.size(), + /*.mem_buffer =*/ pstate.sched_encode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Conv subsampling + + // [freq, time] + struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_mel_time, 1, 1); + ggml_set_name(mel, "mel"); + ggml_set_input(mel); + + // [freq, time, channels, batch] + struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b); + ggml_set_name(cur, "pre_conv_0"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_0_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b); + ggml_set_name(cur, "pre_conv_2"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b); + ggml_set_name(cur, "pre_conv_3"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_3_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1); + ggml_set_name(cur, "pre_conv_5_direct"); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b); + ggml_set_name(cur, "pre_conv_5"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b); + ggml_set_name(cur, "pre_conv_6"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_6_relu"); + + // [freq, time, chan] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // [freq, chan, time] + cur = ggml_cont(ctx0, cur); + + const int n_freq = cur->ne[0]; // 16 + const int n_chan = cur->ne[1]; // 256 + const int n_frames = cur->ne[2]; // time + + // [freq, time, chan, batch] -> [(freq * chan), time] + cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames); + + cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur); + cur = ggml_add(ctx0, cur, model.enc_pre_out_b); + + ggml_set_name(cur, "pre_enc_out"); + + // Encoder + // cur: [n_state, n_enc_time] + + const int n_time = cur->ne[1]; + const bool local_attn = n_time > PARAKEET_LOCAL_ATTN_THRESHOLD; + const int att_left = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int att_right = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int window_size = local_attn ? att_left + att_right + 1 : 2 * n_time - 1; + const int d_half = n_state / 2; + const int mask_dim = local_attn ? window_size : n_time; + + // mask [key, n_time] + struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mask_dim, n_time); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(attn_mask); + + struct ggml_tensor * local_mask = nullptr; + if (local_attn) { + const int chunk = att_left + att_right; + local_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, chunk + window_size - 1, chunk); + ggml_set_name(local_mask, "local_mask"); + ggml_set_input(local_mask); + } + + struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half); + ggml_set_name(pos_freqs, "pos_freqs"); + ggml_set_input(pos_freqs); + + struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size); + ggml_set_name(rel_positions, "rel_positions"); + ggml_set_input(rel_positions); + + struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1); + struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions); + + struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size); + // [n_state, window_size] + struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size); + ggml_set_name(pos_emb, "pos_emb"); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + // FFN1 + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_res", il); + + // norm + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff1_w), layer.norm_ff1_b); + ggml_format_name(cur, "enc_%d_ffn_norm_1", il); + + // ffn_1 + cur = ggml_mul_mat(ctx0, layer.ff1_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_silu", il); + + cur = ggml_mul_mat(ctx0, layer.ff1_linear2_w, cur); + ggml_format_name(cur, "enc_%d_ffn_1", il); + + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + ggml_format_name(cur, "enc_%d_res_ffn", il); + } + + // self attention block using relative positional encoding computed in graph. + { + // [feat, time_frames, 1, 1] + struct ggml_tensor * residual = cur; + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_attn_w), layer.norm_attn_b); + ggml_format_name(cur, "enc_%d_attn_norm", il); + + const int n_head = hparams.n_audio_head; + const int d_head = n_state / n_head; + + // [feat, time_frames, 1, 1] + struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); + struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); + struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); + + Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time); + K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time); + V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time); + + struct ggml_tensor * pos = ggml_mul_mat(ctx0, layer.attn_pos_w, pos_emb); + pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size); + pos = ggml_cont(ctx0, ggml_permute(ctx0, pos, 0, 2, 1, 3)); + + if (local_attn) { + const int chunk = att_left + att_right; + const int n_group = (n_time + chunk - 1) / chunk; + const int n_time_padded = n_group * chunk; + const int n_kv_chunk = chunk + window_size - 1; + const int n_kv_dense = n_kv_chunk * n_group; + const bool need_padding = n_time_padded > n_time; + + Q_cur = ggml_cont(ctx0, ggml_permute(ctx0, Q_cur, 0, 2, 1, 3)); + K_cur = ggml_cont(ctx0, ggml_permute(ctx0, K_cur, 0, 2, 1, 3)); + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 0, 2, 1, 3)); + + // content bias + struct ggml_tensor * bias_u = ggml_reshape_3d(ctx0, layer.attn_pos_bias_u, d_head, 1, n_head); + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, bias_u); + + // position bias + struct ggml_tensor * bias_v = ggml_reshape_3d(ctx0, layer.attn_pos_bias_v, d_head, 1, n_head); + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, bias_v); + + // right pad the time_frame. + struct ggml_tensor * Q_u_padded = need_padding ? + ggml_pad_ext(ctx0, Q_u, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : Q_u; + Q_u_padded = ggml_reshape_4d(ctx0, Q_u_padded, d_head, chunk, n_group, n_head); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * K_padded = ggml_pad_ext(ctx0, K_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > K_padded->ne[1]) { + K_padded = ggml_pad_ext(ctx0, K_padded, 0, 0, 0, n_kv_dense - K_padded->ne[1], 0, 0, 0, 0); + } + + // Create a 4d tensor where each group spans a wide window of + // 512 keys (n_kv_chunk), but moving to the next group (nb[2]) + // only jumps forward by 256 frames (chunk * nb[1]). This creates + // a 256 frame overlap, shared keys in RAM without copies. + struct ggml_tensor * K_chunk = ggml_view_4d(ctx0, K_padded, + d_head, n_kv_chunk, n_group, n_head, + K_padded->nb[1], + (size_t) chunk * K_padded->nb[1], + K_padded->nb[2], + 0); + K_chunk = ggml_cont(ctx0, K_chunk); + + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_chunk, Q_u_padded); + + // The above mul_mat operation, combined with K_chunk's overlapping + // frames, produces a dense matrix. But some of the results in + // this matrix were computed for keys that aren't part of that + // query's window. So we shift each row to keep only the results + // that we want. + content_scores = ggml_view_4d(ctx0, content_scores, + window_size, chunk, n_group, n_head, + (size_t) (chunk + window_size) * content_scores->nb[0], + content_scores->nb[2], + content_scores->nb[3], + 0); + content_scores = ggml_cont(ctx0, content_scores); + + // ungrouping. + content_scores = ggml_reshape_3d(ctx0, content_scores, window_size, n_time_padded, n_head); + + // remove padding if padding was applied (truncating to n_time). + if (need_padding) { + content_scores = ggml_view_3d(ctx0, content_scores, + window_size, n_time, n_head, + content_scores->nb[1], + content_scores->nb[2], + 0); + } + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + + // attention_score = content similarity + relative position scores + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + + attn_scores = ggml_soft_max_ext(ctx0, attn_scores, attn_mask, 1.0f / std::sqrt(d_head), 0.0f); + + // right pad the probabilites. + struct ggml_tensor * probs_padded = need_padding ? + ggml_pad_ext(ctx0, attn_scores, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : attn_scores; + + probs_padded = ggml_reshape_4d(ctx0, probs_padded, window_size, chunk, n_group, n_head); + probs_padded = ggml_pad_ext(ctx0, probs_padded, 0, chunk, 0, 0, 0, 0, 0, 0); + probs_padded = ggml_view_4d(ctx0, probs_padded, + n_kv_chunk, chunk, n_group, n_head, + (size_t) n_kv_chunk * probs_padded->nb[0], + probs_padded->nb[2], + probs_padded->nb[3], + 0); + probs_padded = ggml_cont(ctx0, probs_padded); + probs_padded = ggml_mul(ctx0, probs_padded, local_mask); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * V_padded = ggml_pad_ext(ctx0, V_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > V_padded->ne[1]) { + V_padded = ggml_pad_ext(ctx0, V_padded, 0, 0, 0, n_kv_dense - V_padded->ne[1], 0, 0, 0, 0); + } + + V_padded = ggml_cont(ctx0, ggml_transpose(ctx0, V_padded)); + + struct ggml_tensor * V_chunk = ggml_view_4d(ctx0, V_padded, + n_kv_chunk, d_head, n_group, n_head, + V_padded->nb[1], + (size_t) chunk * V_padded->nb[0], + V_padded->nb[2], + 0); + V_chunk = ggml_cont(ctx0, V_chunk); + + cur = ggml_mul_mat(ctx0, V_chunk, probs_padded); + // ungroup. + cur = ggml_reshape_3d(ctx0, cur, d_head, n_time_padded, n_head); + // unpad + if (need_padding) { + cur = ggml_view_3d(ctx0, cur, d_head, n_time, n_head, cur->nb[1], cur->nb[2], 0); + } + + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } else { + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_u); + ggml_format_name(Q_u, "enc_%d_attn_q_u", il); + + struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3); + struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3); + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep); + ggml_format_name(content_scores, "enc_%d_attn_content_scores", il); + + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v", il); + + Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3); + Q_v = ggml_cont(ctx0, Q_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il); + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il); + + // Relative position shifting is performed in the following block. + // Some more details on the operations performed below can be found here: + // https://github.com/danbev/learning-ai/blob/main/notes/whisper/parakeet.md#relative-position-shift + { + const auto pos_window = rel_pos_scores->ne[0]; + const auto n_frame = rel_pos_scores->ne[1]; + const auto n_head_cur = rel_pos_scores->ne[2]; + + rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0); + rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0); + + rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head_cur); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il); + + int center = pos_window / 2; + size_t offset = rel_pos_scores->nb[0] * (center+1); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + n_frame, pos_window, n_head_cur, + (pos_window) * 4, + rel_pos_scores->nb[2], + offset); + + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + content_scores->ne[0], + content_scores->ne[1], + rel_pos_scores->ne[2], + rel_pos_scores->nb[1], + rel_pos_scores->nb[2], + 0); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il); + } + + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + ggml_format_name(attn_scores, "enc_%d_attn_scores", il); + attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head)); + attn_scores = ggml_add(ctx0, attn_scores, attn_mask); + ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores); + ggml_format_name(probs, "enc_%d_attn_probs", il); + + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3)); + ggml_format_name(V_cur, "enc_%d_attn_v_cur", il); + cur = ggml_mul_mat(ctx0, probs, V_cur); + ggml_format_name(cur, "enc_%d_attn_inp", il); + + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); + cur = ggml_cont_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } + ggml_format_name(cur, "enc_%d_attn_out", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_attn_res", il); + } + + // Convolution + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_residual_conv", il); + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_conv_w), layer.norm_conv_b); + ggml_format_name(cur, "enc_%d_norm_conv", il); + + // pointwise 1d convolution: [1024, 138] -> [2048, 138] + cur = ggml_mul_mat(ctx0, layer.conv_pw1_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw1", il); + + { + int64_t d = cur->ne[0] / 2; + struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0); + struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]); + + cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate)); + ggml_format_name(cur, "enc_%d_conv_glu", il); + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + // use ggml_ssm_conv for f32 precision + const int dw_pad = (hparams.n_conv_kernel - 1) / 2; + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_roll(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + ggml_format_name(cur, "enc_%d_conv_dw_pad", il); + + cur = ggml_ssm_conv(ctx0, cur, layer.conv_dw_w); + ggml_format_name(cur, "enc_%d_conv_1d_dw", il); + + cur = ggml_sub(ctx0, cur, layer.conv_bn_mean); + struct ggml_tensor * std = ggml_sqrt(ctx0, layer.conv_bn_var); + cur = ggml_div(ctx0, cur, std); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.conv_bn_w), layer.conv_bn_b); + ggml_format_name(cur, "enc_%d_conv_bn", il); + + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_conv_silu", il); + + cur = ggml_mul_mat(ctx0, layer.conv_pw2_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw2", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_conv_res", il); + } + + // FFN2 + { + struct ggml_tensor * residual = cur; + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff2_w), layer.norm_ff2_b); + ggml_format_name(cur, "enc_%d_ffn_norm_2", il); + + cur = ggml_mul_mat(ctx0, layer.ff2_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + cur = ggml_mul_mat(ctx0, layer.ff2_linear2_w, cur); + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5)); + ggml_format_name(cur, "enc_%d_ffn_res", il); + } + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_out_w), layer.norm_out_b); + } + + ggml_set_name(cur, "encoder_out"); + pstate.n_frames = cur->ne[1]; + + struct ggml_tensor * enc_out_view = ggml_view_2d(ctx0, pstate.enc_out, n_state, pstate.n_frames, pstate.enc_out->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, cur, enc_out_view)); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_encode_internal( + parakeet_context & pctx, + parakeet_state & pstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + auto & sched = pstate.sched_encode.sched; + + ggml_cgraph * gf = parakeet_build_graph_encode(pctx, pstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set mel input + { + struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + + const auto & mel_inp = pstate.mel; + const int n_ctx = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx; + + assert(mel->type == GGML_TYPE_F32); + assert(mel_inp.n_mel == pctx.model.hparams.n_mels); + + pstate.inp_mel.resize(ggml_nelements(mel)); + + float * dst = pstate.inp_mel.data(); + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + n_ctx, mel_inp.n_len); + + memcpy(dst, mel_inp.data.data() + i0 * mel_inp.n_mel, (i1 - i0) * mel_inp.n_mel * sizeof(float)); + + ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); + } + + // set attention mask + { + struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask"); + const int n_q = attn_mask->ne[1]; + const int n_k = attn_mask->ne[0]; + + const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor - 1) / subsampl_factor; + + std::vector<float> mask_data(n_q * n_k); + const float mask_value = -1e30f; + + if (n_k == n_q) { // full attention + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + mask_data[q * n_k + k] = (k >= n_tokens_real) ? mask_value : 0.0f; + } + } + } else { // local attention + const int att_left = n_k / 2; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int key = q - att_left + k; + mask_data[q * n_k + k] = (key >= 0 && key < n_tokens_real) ? 0.0f : mask_value; + } + } + } + ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set local attention skew mask + if (struct ggml_tensor * local_mask = ggml_graph_get_tensor(gf, "local_mask")) { + const int n_k = local_mask->ne[0]; + const int n_q = local_mask->ne[1]; + + std::vector<float> mask_data(n_q * n_k); + const int window_size = n_k - n_q + 1; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int rel = k - q; + mask_data[q * n_k + k] = (rel >= 0 && rel < window_size) ? 1.0f : 0.0f; + } + } + ggml_backend_tensor_set(local_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set positional frequency + { + struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs"); + const int d_half = pos_freqs_t->ne[0]; + const int n_state = pctx.model.hparams.n_audio_state; + const float log_10000 = logf(10000.0f); + std::vector<float> freqs(d_half); + for (int k = 0; k < d_half; ++k) { + freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state))); + } + ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float)); + } + + // set relative position offsets + { + struct ggml_tensor * rel_pos_t = ggml_graph_get_tensor(gf, "rel_positions"); + const int window_size = rel_pos_t->ne[1]; + std::vector<float> pos(window_size); + if (window_size == PARAKEET_LOCAL_ATTN_WINDOW * 2 + 1) { + for (int t = 0; t < window_size; ++t) { + pos[t] = float(PARAKEET_LOCAL_ATTN_WINDOW - t); + } + } else { + const int n_time = (window_size + 1) / 2; + for (int t = 0; t < window_size; ++t) { + pos[t] = float(n_time - 1 - t); + } + } + ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float)); + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + pstate.t_encode_us += ggml_time_us() - t_start_us; + pstate.n_encode++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_ensure_encode_sched( + parakeet_context & pctx, + parakeet_state & pstate, + int n_audio_ctx) { + if (pstate.sched_encode.sched && pstate.sched_encode_n_audio_ctx == n_audio_ctx) { + return true; + } + + parakeet_sched_free(pstate.sched_encode); + + const int32_t prev_n_audio_ctx = pstate.n_audio_ctx; + pstate.n_audio_ctx = n_audio_ctx; + + const int subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_frames_max = (n_audio_ctx + subsampl_factor - 1) / subsampl_factor; + if (n_frames_max > pstate.enc_out->ne[1]) { + ggml_backend_buffer_free(pstate.enc_out_buffer); + pstate.enc_out_buffer = nullptr; + pstate.enc_out = nullptr; + + if (!parakeet_enc_state_init(pstate, pstate.backends[0], pctx.model.hparams.n_audio_state, n_frames_max)) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + } + + const bool ok = parakeet_sched_graph_init(pstate.sched_encode, pstate.backends, + [&]() { + return parakeet_build_graph_encode(pctx, pstate); + }); + + if (!ok) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + + pstate.sched_encode_n_audio_ctx = n_audio_ctx; + return true; +} + +static struct ggml_tensor * parakeet_build_graph_lstm_layer( + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * x_t, // the current input token embedding + struct ggml_tensor * w_ih, // input to hidden weights (4 weight tensors packed) + struct ggml_tensor * w_hh, // hidden to hidden weights (4 weight tensors packed) + struct ggml_tensor * b_h, // folded ih+hh bias (4 bias tensors packed) + struct ggml_tensor * h_state, // this layers hidden state + struct ggml_tensor * c_state, // this layers cell state + int li) { // layer index (for tensor naming) + + ggml_format_name(x_t, "lstm_layer_%d_x_t", li); + ggml_format_name(h_state, "lstm_layer_%d_h_state", li); + ggml_format_name(c_state, "lstm_layer_%d_c_state", li); + + // The 4 gates (i, f, o, c) are packed in the same weight tensor. + struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t); + + // Hidden-to-Hidden Projections are also packed in the same weight tensor. + // b_h holds the folded ih+hh bias (see parakeet_model_load), so it is + // the only bias that needs to be added here. + struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state); + hid_gates = ggml_add(ctx0, hid_gates, b_h); + + // Combine the input and hidden contributions of the gates. + struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates); + ggml_format_name(gates, "lstm_layer_%d_gates", li); + + const int h_dim = h_state->ne[0]; + const size_t row_size = ggml_row_size(gates->type, h_dim); + + // The gates are packed as [i, f, o, c] (reordered at convert time, see + // parakeet_model_load), so the three sigmoid-gated outputs (i, f, o) are + // contiguous and can be computed with a single ggml_sigmoid call. + struct ggml_tensor * ifo = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, 3 * h_dim, 0)); + ggml_format_name(ifo, "lstm_layer_%d_ifo", li); + + // 1. Input Gate at time t. + struct ggml_tensor * i_t = ggml_view_1d(ctx0, ifo, h_dim, 0 * row_size); + ggml_format_name(i_t, "lstm_layer_%d_i_t", li); + + // Forget gate. + struct ggml_tensor * f_t = ggml_view_1d(ctx0, ifo, h_dim, 1 * row_size); + ggml_format_name(f_t, "lstm_layer_%d_f_t", li); + + // Output gate. + struct ggml_tensor * o_t = ggml_view_1d(ctx0, ifo, h_dim, 2 * row_size); + ggml_format_name(o_t, "lstm_layer_%d_o_t", li); + + // Cell gate. + struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size)); + ggml_format_name(c_t, "lstm_layer_%d_c_t", li); + + // Calculate the new cell state. + struct ggml_tensor * c_new = ggml_add(ctx0, + ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state. + ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state)); + + // Calculate the new hidden state. + struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new)); + ggml_set_output(h_new); + ggml_format_name(h_new, "lstm_layer_%d_h_new", li); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state)); + + return h_new; +} + +static struct ggml_cgraph * parakeet_build_graph_prediction( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Prediction Network + struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(token, "token_inp"); + ggml_set_input(token); + + struct ggml_tensor * token_embd = ggml_get_rows(ctx0, model.prediction.embed_w, token); + + struct ggml_tensor * inpL = token_embd; + + for (int il = 0; il < hparams.n_pred_layers; ++il) { + inpL = parakeet_build_graph_lstm_layer(ctx0, gf, inpL, + model.prediction.lstm_layer[il].ih_w, + model.prediction.lstm_layer[il].hh_w, + model.prediction.lstm_layer[il].b_h, + pstate.lstm_state.layer[il].h_state, + pstate.lstm_state.layer[il].c_state, + il); + } + + struct ggml_tensor * pred_out = inpL; + ggml_format_name(pred_out, "lstm_pred_out"); + + // Project the prediction network output to the joint network hidden dimension. + struct ggml_tensor * pred = ggml_mul_mat(ctx0, model.joint.pred_w, pred_out); + pred = ggml_add(ctx0, pred, model.joint.pred_b); + ggml_set_name(pred, "h_pred"); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, pred, pstate.pred_out)); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * parakeet_build_graph_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + struct ggml_tensor * pred = pstate.pred_out; + ggml_format_name(pred, "pred"); + + const int t_idx = batch.i_time[0]; + struct ggml_tensor * enc_out = ggml_view_1d(ctx0, pstate.enc_out, hparams.n_audio_state, + (size_t) t_idx * pstate.enc_out->nb[1]); + ggml_format_name(enc_out, "enc_out_view"); + + // Project the encoder output to the joint network hidden dimension. + struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out); + enc = ggml_add(ctx0, enc, model.joint.enc_b); + ggml_set_name(enc, "enc"); + + struct ggml_tensor * joint = ggml_add(ctx0, enc, pred); + ggml_set_name(joint, "joint"); + joint = ggml_relu(ctx0, joint); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint); + logits = ggml_add(ctx0, logits, model.joint.net_b); + ggml_set_output(logits); + ggml_set_name(logits, "logits"); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, logits); + struct ggml_tensor * log_probs = ggml_log(ctx0, probs); + ggml_set_output(log_probs); + ggml_format_name(log_probs, "log_probs"); + + ggml_build_forward_expand(gf, log_probs); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_predict( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + + const int n_tokens = batch.n_tokens; + + const int64_t t_start_us = ggml_time_us(); + + { + auto & sched = pstate.sched_decode.sched; + + const int64_t t_build_start_us = ggml_time_us(); + ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false); + pstate.t_predict_build_us += ggml_time_us() - t_build_start_us; + + const int64_t t_alloc_start_us = ggml_time_us(); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + pstate.t_predict_alloc_us += ggml_time_us() - t_alloc_start_us; + + // set the inputs + { + struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp"); + ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp)); + } + + const int64_t t_compute_start_us = ggml_time_us(); + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + pstate.t_predict_compute_us += ggml_time_us() - t_compute_start_us; + } + + pstate.t_predict_us += ggml_time_us() - t_start_us; + pstate.n_predict++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + auto & logits_out = pstate.logits; + + struct ggml_tensor * logits; + + { + auto & sched = pstate.sched_decode.sched; + + ggml_cgraph * gf = parakeet_build_graph_joint(pctx, pstate, batch, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + logits = ggml_graph_node(gf, -1); + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + } + + const int n_logits = hparams.n_vocab + hparams.n_tdt_durations + 1; // one for the blank token + logits_out.resize(n_tokens * n_logits); + for (int i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_logits*i), sizeof(float)*(n_logits*i), sizeof(float)*n_logits); + } + + if (batch.n_tokens == 1) { + pstate.t_decode_us += ggml_time_us() - t_start_us; + pstate.n_decode++; + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + // check if it starts with the SentencePiece meta-space "▁" (U+2581) or 3-byte UTF-8 character: 0xE2 0x96 0x81 + if (!token_str.empty()) { + if (token_str.find("\xE2\x96\x81") == 0 || token_str[0] == '_') { + return true; + } + } + return false; +} + +static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + static const std::string punct_chars = ".,!?;:'\"-()[]{}"; + + if (token_str.empty()) { + return false; + } + + std::string clean_token = token_str; + if (clean_token.find("\xE2\x96\x81") == 0) { + clean_token = clean_token.substr(3); // Remove the 3-byte UTF-8 character + } else if (clean_token[0] == '_') { + clean_token = clean_token.substr(1); + } + + return clean_token.length() == 1 && punct_chars.find(clean_token[0]) != std::string::npos; +} + +// Collapse punctuation timestamps to match the original Parakeet model. +// Punctuations symbols like ',', '.' and others are not spoken words but the +// model will still produce a duration for these tokens. But since these are +// non-spoken we collapse the timestamps so that they don't have an time duration. +static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector<parakeet_token_data> & tokens) { + if (tokens.empty()) { + return; + } + + int64_t last_non_punct_t1 = -1; + + for (size_t i = 0; i < tokens.size(); ++i) { + if (is_punctuation_token(vocab, tokens[i].id)) { + if (last_non_punct_t1 >= 0) { + tokens[i].t0 = last_non_punct_t1; + tokens[i].t1 = last_non_punct_t1; + } + } else { + last_non_punct_t1 = tokens[i].t1; + } + } +} + +static parakeet_token_data create_token_data( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_token token_id, + int duration_idx, + int duration_value, + int frame_index, + float token_logit, + int n_vocab_logits) { + + float token_sum = 0.0f; + for (int i = 0; i < n_vocab_logits; ++i) { + token_sum += expf(pstate.logits[i]); + } + float token_p = expf(token_logit) / token_sum; + + parakeet_token_data token_data; + token_data.id = token_id; + token_data.duration_idx = duration_idx; + token_data.duration_value = duration_value; + token_data.frame_index = frame_index; + token_data.p = token_p; + token_data.plog = token_logit; + token_data.t0 = frame_index * pctx.model.hparams.subsampling_factor; + token_data.t1 = (frame_index + duration_value) * pctx.model.hparams.subsampling_factor; + token_data.is_word_start = is_word_start_token(pctx.vocab, token_id); + + return token_data; +} + +static bool parakeet_decode( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int n_threads, + const parakeet_full_params * params = nullptr) { + const auto & hparams = pctx.model.hparams; + const auto & tdt_durations = pctx.model.tdt_durations; + + const int n_tdt_durations = hparams.n_tdt_durations; + const int n_frames = pstate.n_frames; + const int blank_id = pctx.vocab.token_blank; + const int n_vocab_logits = blank_id + 1; + const int max_tokens_per_timestep = hparams.n_max_tokens; + + // time index into the encoder frame (current time frame) + int t = 0; + // number of symbols emitted for the current time frame + int tokens_emitted = 0; + + // Start with the blank token (8192) + parakeet_token last_token = blank_id; + + PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames); + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 1; + batch.i_time[0] = 0; + + // run the prediction network for the initial blank token. This will + // initialize the LSTM state and produce an initial hidden state that can + // be used in the joint network below. + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // process all time frames of the encoder output + while (t < n_frames) { + batch.n_tokens = 1; + batch.i_time[0] = t; + batch.logits[0] = 1; + + // Use the current encoder frame (t) and the output of the prediction to + // generate probabilities for the next token and duration. batch.i_time + // is used in to select the correct frame from the encoder output. + // The joint network outputs logits for all the tokens in the vocabulary + // plus the blank token, and also n_duration logits for the duration + // tokens which contain information about how many frames to skip/advance forward. + if (!parakeet_joint(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // find the best token (greedy). + // TODO: implement beam search? + int best_token = 0; + float max_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > max_logit) { + max_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. + int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < n_tdt_durations; ++i) { + if (pstate.logits[n_vocab_logits + i] > best_duration_logit) { + best_duration_logit = pstate.logits[n_vocab_logits + i]; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + // skip forward by duration time frames. + t += duration; + // reset symbols emitted counter + tokens_emitted = 0; + // continue without predicting. + continue; + } + + // Emit non-blank token at current frame t. + pstate.decoded_tokens.push_back(best_token); + pstate.t_sample_us += ggml_time_us() - t_start_sample_us; + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, best_token, best_duration_idx, duration, t, + max_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // Call token callback if registered (for real-time streaming) + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + last_token = best_token; + + // advance predictor for the non-blank token. + batch.token[0] = last_token; + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // if duration greater than 0, continue looping over the encoder frames + // and skip to the updated time frame (t). + if (duration > 0) { + t += duration; + tokens_emitted = 0; + continue; + } + + // if duration is zero we stay on the current time frame. + tokens_emitted++; + if (tokens_emitted >= max_tokens_per_timestep) { + t += 1; // forced blank/time advance behavior + tokens_emitted = 0; + } + } + + return true; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) { + const int sin_cos_step = cache.n_fft / N; + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N + re += in[n]*cache.cos_vals[idx]; // cos(t) + im -= in[n]*cache.sin_vals[idx]; // sin(t) + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) { + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + const int half_N = N / 2; + if (N - half_N*2 == 1) { + dft(in, N, out, cache); + return; + } + + float* even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i]= in[2*i]; + } + float* even_fft = out + 2 * N; + fft(even, half_N, even_fft, cache); + + float* odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2*i + 1]; + } + float* odd_fft = even_fft + N; + fft(odd, half_N, odd_fft, cache); + + const int sin_cos_step = cache.n_fft / N; + for (int k = 0; k < half_N; k++) { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; // cos(t) + float im = -cache.sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +struct mel_worker_params { + int ith; + int window_size; + int n_samples; + int frame_size; + int frame_step; + int n_threads; +}; + +static void log_mel_spectrogram_worker_thread( + mel_worker_params params, + const float * window_func, + const std::vector<float> & samples, + const parakeet_filters & filters, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + std::vector<float> fft_in(params.frame_size * 2, 0.0); + std::vector<float> fft_out(params.frame_size * 2 * 2 * 2); + + int n_fb = filters.n_fb; // number of frequency bins + int i = params.ith; + + // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist + assert(n_fb == 1 + (params.frame_size / 2)); + + const double eps = 5.960464477539063e-08; + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) { + const int offset = i * params.frame_step; + + const int window_pad_left = (params.frame_size - params.window_size) / 2; + + // Zero-pad left + std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f); + + // Apply windowed samples in the center + const int n_to_process = std::min({params.window_size, params.n_samples - offset}); + for (int j = 0; j < n_to_process; j++) { + fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j]; + } + + // Zero-pad right (and any samples we didn't have) + std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f); + + // FFT + fft(fft_in.data(), params.frame_size, fft_out.data(), cache); + + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. + for (int j = 0; j < n_fb; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fb - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j * n_fb + k + 0] + + fft_out[k + 1] * filters.data[j * n_fb + k + 1] + + fft_out[k + 2] * filters.data[j * n_fb + k + 2] + + fft_out[k + 3] * filters.data[j * n_fb + k + 3]; + } + // handle n_fb remainder + for (; k < n_fb; k++) { + sum += fft_out[k] * filters.data[j * n_fb + k]; + } + + mel.data[i * mel.n_mel + j] = std::log(sum + eps); + } + } + + // Otherwise fft_out are all zero - use log(eps) for consistency + const double empty_sum = std::log(eps); + for (; i < mel.n_len; i += params.n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[i * mel.n_mel + j] = empty_sum; + } + } +} + +static bool log_mel_spectrogram( + parakeet_state & wstate, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const parakeet_filters & filters, + const bool debug, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + const int64_t t_start_us = ggml_time_us(); + + const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data(); + const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size(); + + std::vector<float> samples_preprocessed(samples, samples + n_samples); + + // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1] + { + const float preemph = 0.97f; + for (int i = n_samples - 1; i > 0; i--) { + samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1]; + } + } + + // Parakeet Pytorch implementation uses centered contant padding. + const size_t pad = (size_t)(frame_size / 2); + std::vector<float> samples_padded(n_samples + 2 * pad, 0.0f); + std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad); + + mel.n_mel = n_mel; + mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1; + mel.n_len_org = mel.n_len; + mel.data.resize(mel.n_mel * mel.n_len); + + // Worker Threads (STFT + Mel + Natural Log) + { + std::vector<std::thread> workers(n_threads - 1); + const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_threads }; + + for (int iw = 0; iw < n_threads - 1; ++iw) { + mel_worker_params params = mel_params; + params.ith = iw + 1; + workers[iw] = std::thread(log_mel_spectrogram_worker_thread, + params, + window_func, + std::cref(samples_padded), + std::cref(filters), + std::ref(mel), + std::cref(cache)); + } + + log_mel_spectrogram_worker_thread( + mel_params, + window_func, + samples_padded, + filters, + mel, + cache); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + { + const double eps = 1e-5; + int valid_frames = n_samples / frame_step; + + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + double sq_diff_sum = 0.0; + + // Calculate Mean ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + sum += (double)mel.data[i * mel.n_mel + j]; + } + double mean = sum / valid_frames; + + // Calculate Variance ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + double diff = (double)mel.data[i * mel.n_mel + j] - mean; + sq_diff_sum += diff * diff; + } + + double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0)); + double denominator = std_dev + eps; + + // Apply to ALL frames (including the padded ones) + for (int i = 0; i < mel.n_len; i++) { + mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + j] - mean) / denominator); + } + } + } + + wstate.t_mel_us += ggml_time_us() - t_start_us; + + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } + + return true; +} + +static std::vector<parakeet_vocab::id> tokenize(const parakeet_vocab & vocab, const std::string & text) { + std::vector<parakeet_vocab::id> tokens; + const std::string normalized = sentencepiece_normalize(text); + + size_t i = 0; + while (i < normalized.size()) { + const size_t remaining = normalized.size() - i; + const size_t max_len = std::min(vocab.max_token_length, remaining); + + bool found = false; + for (size_t len = max_len; len > 0; --len) { + const auto it = vocab.token_to_id.find(normalized.substr(i, len)); + if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) { + tokens.push_back(it->second); + i += len; + found = true; + break; + } + } + + if (!found) { + if (vocab.token_unk >= 0) { + tokens.push_back(vocab.token_unk); + } + + const unsigned char c = static_cast<unsigned char>(normalized[i]); + i += utf8_codepoint_len(c); + } + } + + return tokens; +} + + +// +// interface implementation +// + +struct parakeet_state * parakeet_init_state(parakeet_context * ctx) { + parakeet_state * state = new parakeet_state; + + state->backends = parakeet_backend_init(ctx->params); + if (state->backends.empty()) { + PARAKEET_LOG_ERROR("%s: parakeet_backend_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const int batch_size = ctx->model.hparams.n_audio_ctx; + + state->logits.reserve(ctx->vocab.n_vocab * batch_size); + + state->batch = parakeet_batch_init(batch_size); + + { + const int n_audio_state = ctx->model.hparams.n_audio_state; + const int subsampl_factor = ctx->model.hparams.subsampling_factor; + const int n_frames_max = (batch_size + subsampl_factor - 1) / subsampl_factor; + + if (!parakeet_enc_state_init(*state, state->backends[0], n_audio_state, n_frames_max)) { + PARAKEET_LOG_ERROR("%s: parakeet_enc_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const size_t mem_enc_ctx = state->enc_out_buf.size(); + const size_t mem_enc_out_buf = ggml_backend_buffer_get_size(state->enc_out_buffer); + PARAKEET_LOG_INFO("%s: enc_out state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_enc_ctx / 1024.0 / 1024.0, mem_enc_out_buf / 1024.0 / 1024.0); + } + + // conv/encoder allocator + bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends, + [&]() { + return parakeet_build_graph_encode(*ctx, *state); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init encode allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + state->sched_encode_n_audio_ctx = state->n_audio_ctx > 0 ? state->n_audio_ctx : ctx->model.hparams.n_audio_ctx; + + if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers, ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_lstm_states_init () failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size(); + const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer); + PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0); + } + + if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_pred_ctx = state->pred_out_buf.size(); + const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer); + PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0); + } + + PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6); + + { + bool ok = parakeet_sched_graph_init(state->sched_decode, state->backends, + [&]() { + const auto & hparams = ctx->model.hparams; + const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet + + parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0); + + return parakeet_build_graph_prediction(*ctx, *state, state->batch, true); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6); + } + + return state; +} + +struct parakeet_context_params parakeet_context_default_params() { + struct parakeet_context_params result = { + /*.use_gpu =*/ true, + /*.gpu_device =*/ 0, + }; + return result; +} + +struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) { + PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert<std::codecvt_utf8<wchar_t>> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) { + PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + parakeet_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + auto ctx = parakeet_init_with_params_no_state(&loader, params); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + struct buf_context { + uint8_t* buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 }; + + PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__); + + parakeet_model_loader loader = {}; + + loader.context = &ctx; + + loader.read = [](void * ctx, void * output, size_t read_size) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void * ctx) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) { }; + + return parakeet_init_with_params_no_state(&loader, params); +} + +struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + ggml_time_init(); + + PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count()); + PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count()); + + parakeet_context * ctx = new parakeet_context; + ctx->params = params; + + bool model_loaded = false; + try { + model_loaded = parakeet_model_load(loader, *ctx); + } catch (const std::exception & e) { + PARAKEET_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + PARAKEET_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { + loader->close(loader->context); + PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + // Initialize mel cache with model's FFT size + ctx->mel_cache.init(ctx->model.hparams.n_fft); + PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft); + + return ctx; +} + +struct parakeet_context * parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +void parakeet_free_state(struct parakeet_state * state) { + if (state) { + ggml_backend_buffer_free(state->lstm_state.buffer); + ggml_backend_buffer_free(state->pred_out_buffer); + ggml_backend_buffer_free(state->enc_out_buffer); + + parakeet_batch_free(state->batch); + + parakeet_sched_free(state->sched_encode); + parakeet_sched_free(state->sched_decode); + + for (auto & backend : state->backends) { + ggml_backend_free(backend); + } + + delete state; + } +} + +void parakeet_free(struct parakeet_context * ctx) { + if (ctx) { + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } + + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } + + parakeet_free_state(ctx->state); + + delete ctx; + } +} + +void parakeet_free_context_params(struct parakeet_context_params * params) { + if (params) { + delete params; + } +} + +void parakeet_free_params(struct parakeet_full_params * params) { + if (params) { + delete params; + } +} + +int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, + samples, + n_samples, + PARAKEET_SAMPLE_RATE, + ctx->model.hparams.n_fft, + PARAKEET_HOP_LENGTH, + ctx->model.filters.n_mel, + n_threads, + ctx->model.filters, + false, // debug + state->mel, + ctx->mel_cache)) { + PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) { + return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != ctx->model.filters.n_mel) { + PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel) { + return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_encode(struct parakeet_context * ctx, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -(int) res.size(); + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int parakeet_token_count(struct parakeet_context * ctx, const char * text) { + return -parakeet_tokenize(ctx, text, NULL, 0); +} + +int parakeet_model_n_vocab(struct parakeet_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int parakeet_model_n_audio_state(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int parakeet_model_n_audio_head(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int parakeet_model_n_audio_layer(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int parakeet_model_n_mels(struct parakeet_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int parakeet_model_ftype(struct parakeet_context * ctx) { + return ctx->model.hparams.ftype; +} + +int parakeet_n_len_from_state(struct parakeet_state * state) { + return state->mel.n_len_org; +} + +int parakeet_n_len(struct parakeet_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int parakeet_n_vocab(struct parakeet_context * ctx) { + return ctx->vocab.n_vocab; +} + +int parakeet_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +float * parakeet_get_logits(struct parakeet_context * ctx) { + return ctx->state->logits.data(); +} + +float * parakeet_get_logits_from_state(struct parakeet_state * state) { + return state->logits.data(); +} + +const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) { + std::string text = sentencepiece_piece_to_text(token_str, is_first); + + if (output == nullptr) { + return text.size(); + } + + int bytes_to_copy = std::min((int)text.size(), max_len - 1); + if (bytes_to_copy > 0) { + memcpy(output, text.c_str(), bytes_to_copy); + output[bytes_to_copy] = '\0'; + } else if (max_len > 0) { + output[0] = '\0'; + } + + return text.size(); +} + +parakeet_token parakeet_token_bos(struct parakeet_context * ctx) { + return ctx->vocab.token_bos; +} + +parakeet_token parakeet_token_unk(struct parakeet_context * ctx) { + return ctx->vocab.token_unk; +} + +parakeet_token parakeet_token_blank(struct parakeet_context * ctx) { + return ctx->vocab.token_blank; +} + +struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) { + if (ctx->state == nullptr) { + return nullptr; + } + parakeet_timings * timings = new parakeet_timings; + timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); + timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); + timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); + return timings; +} + +void parakeet_print_timings(struct parakeet_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + PARAKEET_LOG_INFO("\n"); + PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_predict = std::max(1, ctx->state->n_predict); + + PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + PARAKEET_LOG_INFO("%s: predict time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_us, n_predict, 1e-3f * ctx->state->t_predict_us / n_predict); + PARAKEET_LOG_INFO("%s: - build = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_build_us, n_predict, 1e-3f * ctx->state->t_predict_build_us / n_predict); + PARAKEET_LOG_INFO("%s: - alloc = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_alloc_us, n_predict, 1e-3f * ctx->state->t_predict_alloc_us / n_predict); + PARAKEET_LOG_INFO("%s: - compute = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_compute_us, n_predict, 1e-3f * ctx->state->t_predict_compute_us / n_predict); + + } + PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void parakeet_reset_timings(struct parakeet_context * ctx) { + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + ctx->state->t_predict_us = 0; + ctx->state->t_predict_build_us = 0; + ctx->state->t_predict_alloc_us = 0; + ctx->state->t_predict_compute_us = 0; + + ctx->state->n_sample = 0; + ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_predict = 0; + } +} + +const char * parakeet_print_system_info(void) { + static std::string s; + + s = ""; + s += "PARAKEET : "; + + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto * reg = ggml_backend_reg_get(i); + auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (get_features_fn) { + ggml_backend_feature * features = get_features_fn(reg); + s += ggml_backend_reg_name(reg); + s += " : "; + for (; features->name; features++) { + s += features->name; + s += " = "; + s += features->value; + s += " | "; + } + } + } + return s.c_str(); +} + +struct parakeet_context_params * parakeet_context_default_params_by_ref(void) { + struct parakeet_context_params params = parakeet_context_default_params(); + + struct parakeet_context_params* result = new parakeet_context_params(); + *result = params; + return result; +} + +struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params params = parakeet_full_default_params(strategy); + + struct parakeet_full_params* result = new parakeet_full_params(); + *result = params; + return result; +} + +struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params result = { + /*.strategy =*/ strategy, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + /*.no_context =*/ true, + /*.audio_ctx =*/ 0, + /*.new_token_callback =*/ nullptr, + /*.new_token_callback_user_data =*/ nullptr, + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + }; + + return result; +} + +static void parakeet_reset_state(struct parakeet_state * state) { + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + + if (state->lstm_state.buffer) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + } + +} + +// Encode and decode the mel spectrogram already in state, without recomputing it. +static int parakeet_chunk_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params) { + return parakeet_chunk(ctx, state, params, nullptr, 0); +} + +int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + state->result_all.clear(); + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + const int n_mel_total = state->mel.n_len; + const int n_audio_ctx = ctx->model.hparams.n_audio_ctx; + + if (n_mel_total <= n_audio_ctx) { + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + return parakeet_chunk_with_state(ctx, state, params); + } + + PARAKEET_LOG_DEBUG("%s: audio too long (%d mel > n_audio_ctx=%d), using dynamic encoder graph\n", + __func__, n_mel_total, n_audio_ctx); + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false\n", __func__); + return -6; + } + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + + if (!parakeet_ensure_encode_sched(*ctx, *state, n_mel_total)) { + PARAKEET_LOG_ERROR("%s: failed to allocate dynamic encoder graph for %d mel frames\n", + __func__, n_mel_total); + return -6; + } + + state->n_audio_ctx = n_mel_total; + + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 100, params.progress_callback_user_data); + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * tok_str = parakeet_token_to_str(ctx, token_id); + if (tok_str) { + const bool is_first = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(tok_str, is_first); + } + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment seg; + seg.t0 = 0; + seg.t1 = state->n_frames; + seg.text = text; + seg.tokens = result_tokens; + state->result_all.push_back(std::move(seg)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + if (params.audio_ctx == 0) { + const int total_len = parakeet_n_len_from_state(state); + const int model_max_ctx = parakeet_n_audio_ctx(ctx); + params.audio_ctx = std::min(total_len, model_max_ctx); + PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx); + } + state->n_audio_ctx = params.audio_ctx; + + const int n_frames = parakeet_n_len_from_state(state); + + if (!parakeet_ensure_encode_sched(*ctx, *state, state->n_audio_ctx)) { + PARAKEET_LOG_ERROR("%s: failed to allocate encoder graph for %d mel frames\n", + __func__, state->n_audio_ctx); + return -6; + } + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); + return -6; + } + } + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + // Use the stored token data from parakeet_decode + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = 0; // Caller tracks timing + segment.t1 = n_frames; + segment.text = text; + segment.tokens = result_tokens; + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full_n_segments_from_state(struct parakeet_state * state) { + return state->result_all.size(); +} + +int parakeet_full_n_segments(struct parakeet_context * ctx) { + return ctx->state->result_all.size(); +} + +int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment); +} + +int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t1_from_state(ctx->state, i_segment); +} + +const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +void parakeet_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : parakeet_log_callback_default; + g_state.log_callback_user_data = user_data; + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); +} + +const char * parakeet_version(void) { + return PARAKEET_VERSION; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + char buffer[1024]; + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args); +} + +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; +#ifndef PARAKEET_DEBUG + if (level == GGML_LOG_LEVEL_DEBUG) { + return; + } +#endif + fputs(text, stderr); + fflush(stderr); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 646f45f2ab7..74a5b142948 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -118,3 +118,62 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Parakeet model loading test +set(PARAKEET_TEST test-parakeet) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-ggml-parakeet-tdt.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) +set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;gh") + +# The following parakeet test require a real ggml-parakeet-tdt model to have +# been converted or downloaded: +# $ hf download danbev/parakeet parakeet-tdt-0.6b-v3-f32.bin --local-dir models +# +# And also required more audio samples that are shipped by default. These can +# downloaded by running: +# $ make samples +function(add_parakeet_transcription_test TEST_TARGET TEST_SOURCE SAMPLE_PATH EXPECTED_TRANSCRIPTION_PATH) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "1.0") + if (ARGC GREATER 4) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "${ARGV4}") + endif() + + add_executable(${TEST_TARGET} ${TEST_SOURCE}) + target_include_directories(${TEST_TARGET} PRIVATE ../include ../ggml/include ../examples) + target_link_libraries(${TEST_TARGET} PRIVATE parakeet common) + target_compile_definitions(${TEST_TARGET} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3-f32.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/${SAMPLE_PATH}" + EXPECTED_TRANSCRIPTION_PATH="${PROJECT_SOURCE_DIR}/${EXPECTED_TRANSCRIPTION_PATH}" + TRANSCRIPTION_SIMILARITY_THRESHOLD=${TRANSCRIPTION_SIMILARITY_THRESHOLD}) + + add_custom_target(run-${TEST_TARGET} + COMMAND $<TARGET_FILE:${TEST_TARGET}> + DEPENDS ${TEST_TARGET} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) +endfunction() + +add_parakeet_transcription_test( + test-parakeet-full-jfk + test-parakeet-full.cpp + samples/jfk.wav + tests/parakeet-expected-jfk-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-gb1 + test-parakeet-full.cpp + samples/gb1.wav + tests/parakeet-expected-gb1-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-diffusion + test-parakeet-full.cpp + samples/diffusion2023-07-03.flac + tests/parakeet-expected-diffusion-output.txt + 0.95) + diff --git a/tests/librispeech-parakeet/.gitignore b/tests/librispeech-parakeet/.gitignore new file mode 100644 index 00000000000..838bfeae9db --- /dev/null +++ b/tests/librispeech-parakeet/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.tar.gz +*.txt +eval.conf +venv +LibriSpeech diff --git a/tests/librispeech-parakeet/Makefile b/tests/librispeech-parakeet/Makefile new file mode 100644 index 00000000000..0afa2465f49 --- /dev/null +++ b/tests/librispeech-parakeet/Makefile @@ -0,0 +1,15 @@ +TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz + +all: eval + +eval: + $(MAKE) -f eval.mk + +clean: + $(MAKE) -f eval.mk clean + +get-audio: + wget -c $(TAR_URL) + tar -xf test-clean.tar.gz + +.PHONY: all eval clean setup-venv clean-venv get-audio diff --git a/tests/librispeech-parakeet/README.md b/tests/librispeech-parakeet/README.md new file mode 100644 index 00000000000..e09cba405ef --- /dev/null +++ b/tests/librispeech-parakeet/README.md @@ -0,0 +1,57 @@ +# parakeet.cpp/tests/librispeech + +[LibriSpeech](https://www.openslr.org/12) is a standard dataset for +training and evaluating automatic speech recognition systems. + +This directory contains a set of tools to evaluate the recognition +performance of parakeet.cpp on LibriSpeech corpus. + +## Quick Start + +1. (Pre-requirement) Compile `parakeet-cli` and prepare the Parakeet + model in `ggml` format. + + ``` + $ # Execute the commands below in the project root dir. + $ cmake -B build + $ cmake --build build --config Release + ``` + +2. Download the audio files from LibriSpeech project. + + ``` + $ make get-audio + ``` + +3. Set up the environment to compute WER score. + + ``` + $ pip install -r requirements.txt + ``` + + For example, if you use `virtualenv`, you can set up it as follows: + + ``` + $ python3 -m venv venv + $ . venv/bin/activate + $ pip install -r requirements.txt + ``` + +4. Run the benchmark test. + + ``` + $ make + ``` + +## How-to guides + +### How to change the inference parameters + +Create `eval.conf` and override variables. + +``` +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 +PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt +``` + +Check out `eval.mk` for more details. diff --git a/tests/librispeech-parakeet/eval.mk b/tests/librispeech-parakeet/eval.mk new file mode 100644 index 00000000000..7d8992ec471 --- /dev/null +++ b/tests/librispeech-parakeet/eval.mk @@ -0,0 +1,39 @@ +PYTHON = python + +PARAKEET_PREFIX = ../../ +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 + +PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli +PARAKEET_FLAGS = --no-prints --output-txt + +# You can create eval.conf to override the PARAKEET_* variables +# defined above. +-include eval.conf + +# This follows the file structure of the LibriSpeech project. +AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac)) +TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS)) + +# We output the evaluation result to this file. +DONE = $(PARAKEET_MODEL).txt + +all: $(DONE) + +$(DONE): $(TRANS_TXTS) + $(PYTHON) eval.py > $@.tmp + mv $@.tmp $@ + +# Note: This task writes to a temporary file first to +# create the target file atomically. +%.flac.txt: %.flac + $(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp + mv $^.tmp.txt $^.txt + +archive: + tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE) + +clean: + @rm -f $(TRANS_TXTS) + @rm -f $(DONE) + +.PHONY: all clean diff --git a/tests/librispeech-parakeet/eval.py b/tests/librispeech-parakeet/eval.py new file mode 100644 index 00000000000..cdaf8352fd4 --- /dev/null +++ b/tests/librispeech-parakeet/eval.py @@ -0,0 +1,47 @@ +import os +import glob +import jiwer +from normalizers import EnglishTextNormalizer + +def get_reference(): + ref = {} + for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'): + with open(path) as fp: + for line in fp: + code, text = line.strip().split(" ", maxsplit=1) + ref [code] = text + return ref + +def get_hypothesis(): + hyp = {} + for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'): + with open(path) as fp: + text = fp.read().strip() + code = os.path.basename(path).replace('.flac.txt', '') + hyp[code] = text + return hyp + +def get_codes(): + codes = [] + for path in glob.glob('LibriSpeech/*/*/*/*.flac'): + codes.append(os.path.basename(path).replace('.flac', '')) + return sorted(codes) + +def main(): + normalizer = EnglishTextNormalizer() + + ref_orig = get_reference() + hyp_orig = get_hypothesis() + + ref_clean = [] + hyp_clean = [] + + for code in get_codes(): + ref_clean.append(normalizer(ref_orig[code])) + hyp_clean.append(normalizer(hyp_orig[code])) + + wer = jiwer.wer(ref_clean, hyp_clean) + print(f"WER: {wer * 100:.2f}%") + +if __name__ == '__main__': + main() diff --git a/tests/librispeech-parakeet/normalizers/LICENSE b/tests/librispeech-parakeet/normalizers/LICENSE new file mode 100644 index 00000000000..7c8e603b0fc --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/LICENSE @@ -0,0 +1,25 @@ +Code in this directory is adapted from OpenAI Whisper project +(https://github.com/openai/whisper) and carries the following +copyright and license. + + MIT License + + Copyright (c) 2022 OpenAI + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/tests/librispeech-parakeet/normalizers/__init__.py b/tests/librispeech-parakeet/normalizers/__init__.py new file mode 100644 index 00000000000..896d5e33641 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/__init__.py @@ -0,0 +1,2 @@ +from .basic import BasicTextNormalizer as BasicTextNormalizer +from .english import EnglishTextNormalizer as EnglishTextNormalizer diff --git a/tests/librispeech-parakeet/normalizers/basic.py b/tests/librispeech-parakeet/normalizers/basic.py new file mode 100644 index 00000000000..8690ae71c5f --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/basic.py @@ -0,0 +1,80 @@ +import re +import unicodedata + +import regex + +# non-ASCII letters that are not separated by "NFKD" normalization +ADDITIONAL_DIACRITICS = { + "œ": "oe", + "Œ": "OE", + "ø": "o", + "Ø": "O", + "æ": "ae", + "Æ": "AE", + "ß": "ss", + "ẞ": "SS", + "đ": "d", + "Đ": "D", + "ð": "d", + "Ð": "D", + "þ": "th", + "Þ": "th", + "ł": "l", + "Ł": "L", +} + + +def remove_symbols_and_diacritics(s: str, keep=""): + """ + Replace any other markers, symbols, and punctuations with a space, + and drop any diacritics (category 'Mn' and some manual mappings) + """ + return "".join( + ( + c + if c in keep + else ( + ADDITIONAL_DIACRITICS[c] + if c in ADDITIONAL_DIACRITICS + else ( + "" + if unicodedata.category(c) == "Mn" + else " " if unicodedata.category(c)[0] in "MSP" else c + ) + ) + ) + for c in unicodedata.normalize("NFKD", s) + ) + + +def remove_symbols(s: str): + """ + Replace any other markers, symbols, punctuations with a space, keeping diacritics + """ + return "".join( + " " if unicodedata.category(c)[0] in "MSP" else c + for c in unicodedata.normalize("NFKC", s) + ) + + +class BasicTextNormalizer: + def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): + self.clean = ( + remove_symbols_and_diacritics if remove_diacritics else remove_symbols + ) + self.split_letters = split_letters + + def __call__(self, s: str): + s = s.lower() + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space + + return s diff --git a/tests/librispeech-parakeet/normalizers/english.json b/tests/librispeech-parakeet/normalizers/english.json new file mode 100644 index 00000000000..74a1c3521d9 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.json @@ -0,0 +1,1741 @@ +{ + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + "aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": "anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archeological": "archaeological", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology</span>", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + "bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + "breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + "canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + "circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": "colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + "cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": "demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + "demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + "dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "edoema": "edema", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": "energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": "epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + "feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": "fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + "gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "glamor": "glamour", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + "haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": "humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": "immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": "ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + "jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": "liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": "metabolizing", + "metre": "meter", + "metres": "meters", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + "mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + "optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", + "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + "pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": "polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": "pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummel", + "pummelling": "pummeled", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": "randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", + "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + "ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + "serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": "shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplow", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": "standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + "synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + "syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + "tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": "vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": "videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipping": "worshiping", + "worshipper": "worshiper", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts", + "mhm": "hmm", + "mmm": "hmm" +} \ No newline at end of file diff --git a/tests/librispeech-parakeet/normalizers/english.py b/tests/librispeech-parakeet/normalizers/english.py new file mode 100644 index 00000000000..4932042bc5b --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.py @@ -0,0 +1,550 @@ +import json +import os +import re +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union + +from more_itertools import windowed + +from .basic import remove_symbols_and_diacritics + + +class EnglishNumberNormalizer: + """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + self.ones = { + name: i + for i, name in enumerate( + [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ], + start=1, + ) + } + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") + for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = { + name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() + } + self.tens_ordinal = { + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() + } + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = { + name + "s": (value, "s") for name, value in self.multipliers.items() + } + self.multipliers_ordinal = { + name + "th": (value, "th") for name, value in self.multipliers.items() + } + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "£", + "pounds": "£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "¢", + "cents": "¢", + } + self.prefixes = set( + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) + ) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = set( + [ + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + ] + ) + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for prev, current, next in windowed([None] + words + [None], 3): + if skip: + skip = False + continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + assert f is not None + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit + assert value[-1] == "0" + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + assert value[-1] == "0" + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." + else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace "<number> and a half" with "<number> point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and ¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. + + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self): + mapping_path = os.path.join(os.path.dirname(__file__), "english.json") + self.mapping = json.load(open(mapping_path)) + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. + r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer() + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space + + return s diff --git a/tests/parakeet-expected-diffusion-output.txt b/tests/parakeet-expected-diffusion-output.txt new file mode 100644 index 00000000000..9753a86953a --- /dev/null +++ b/tests/parakeet-expected-diffusion-output.txt @@ -0,0 +1 @@ +Hello and welcome to Diffusion. Sit back and relax while we stretch your brain with weird and wonderful science. I'm Ian Wolf. On this edition, Dr. Viv Robinson rewrites cosmology. But first up, here's news of two massive galaxies that might be older than the Big Bang. Galaxies too massive. Astronomers from the Swinburne University of Technology in Melbourne, using the James Webb Space Telescope, have observed six galaxies that formed in the universe's first 700 million years appear to be up to a hundred times more massive than our best theories say can possibly exist. Astronomer Ivo Labe and his colleagues wrote in his paper, adding up the stars in those galaxies, it would exceed the total amount of mass available in the universe at that time. There's too much mass and not enough time for it to get together. The galaxies must have had much longer than the 700 million years after the Big Bang that our standard model of the universe gives them, and the universe must have had more mass available, or galaxies must have formed differently than what we think. The Big Bang is currently thought to have started everything 13.77 billion years ago. And these galaxies, we're watching them at 0.77 billion years ago because they're so far away. Galaxies are thought to accumulate gas moved together by giant clumps of dark matter in their region. Generally, only about 10% of the gas in the galaxy ignites to make a star. For galaxies in the remotest parts of the universe where the gas is thin, it takes a long time to accumulate this much gas for this many stars. These six galaxies, however, have so many stars adding up to so much mass that all of the gas in each galaxy had to have become 100% converted into stars in the 700 million years since the universe started in the Big Bang. Under our current understanding, this is impossible. It suggests something in our understanding of the cosmos is wrong. Are we wrong about how to calculate astronomical masses, galaxy formation, dark matter, and the Big Bang and the age of the universe? An astronomer from the Cosmic Dawn Centre in Denmark used the James Webb telescope to look at closer galaxies, and then used the very high resolution of that telescope to calculate the mass more precisely with a different method, and found that these galaxies are three to ten times more massive than we previously thought. Applying this more accurate technique to the six galaxies that are 13 billion light years away would increase their mass, which makes it much worse than what we thought. The paper was titled A Population of Red Candidate Massive Galaxies, approximately 600 million years after the Big Bang, and was published in the journal Nature.com. We're brought to you across Australia on the Community Radio Network and podcast over the internet on www.diffusionradio.com Challenging Physics Newton said everything is either a particle or a wave. Faraday and Maxwell added fields. Einstein added space-time. Quantum physics says everything is made of quanta, which have the properties of both waves and particles, but is neither. Quantum mechanics has no explanation for gravity, and relativity doesn't account for the quantum world. There's a contradiction between our most basic explanations of the universe. Dr. Viv Robinson was the first person to create a physical explanation of Einstein's gravity in a paper published in the Journal of Physics Communications. He's made corrections to people's extensions of Einstein's mathematics and has a different way to interpret those mathematics that gives a different picture of the age of the universe and a different way of looking at how the physics works. From the standard model of quantum physics to Big Bang cosmology. Everything, including you and me, is made of light. It's a very big and very bold claim. I spoke to Dr. Viv Robinson via Zoom and began by asking him, what is the universe made of? The whole stuff of the universe, or entity. I won't call it items because one of them is absolutely nothing. The first thing to all the mass and all the energy is made up of photons. They're little packets of electromagnetic energy, postulated by Maxwell and Planck and proven by Einstein. They come in many different sizes, shapes, and which make that they make up all the mass and energy of the universe. The volume is made up by empty space, absolutely nothing. But it's the properties of the space that are important. And it does this through two of its properties, electric permittivity and magnetic permeability. And it's those properties which then transmit all of the fields. So that's really all it is. They're just the only two stars in a call because the photos are physical things, and space is just the absence of everything, but its property, its properties are what is important about it. And that's a little bit different to what you might hear from a quantum physics class where they talk about space being full of virtual particles coming into and out of existence so that it's not totally empty, or sometimes they say it's full of fields. The fields of every force is in there and things are coming up all the time. So if you go very fast, you'll interact with the fields, all the virtual particles, and you'll get radiation. Yes, well, uh the unfortunate part is that physics is doing exceedingly well under Newtonian mechanics and exceedingly well under Maxwell's mechanics. But as things get smaller and smaller, you get to a stage where things aren't continuous. I mean, Newton's work will anything that's continuous, but eventually you get to the stage where you know a droplet of water is fine, it has surface tension, evaporates, and you're left with one molecule of water. That doesn't behave the same as bulk water. Into that molecule you go hydrogen atoms and oxygen atoms, they behave nothing like water. And then you get, well, they're made of protons, neutrons, electrons, and they have completely different properties from bulk water. So quantum mechanics, things get quantized, and you get the smallest quantity you can get, and that has very, very different properties from the bulk. And what has happened in the past is that uh the uh early on in quantum mechanics and met men like Dirac and Schrdinger, they didn't know what an the structure was an electron was. Also, all they had to know, they knew it was it had wave properties. And so all they did was they attributed it to a way a wave property to it. Now, waves have the advantage over particles, you can manipulate them almost forever with all sorts of different transforms until you get the answer you want. And that gave some confidence to quantum mechanics guys that yes, waves work, and they've been using that forever, and all I'm saying, no, no, no, no, no. Everything is particles, and the particles have specific properties, and you can't manipulate those properties, or you can to a certain extent, but they are what they are, and it's when you know what those properties are that the whole quantum mechanics becomes much simpler. You don't need any of that uh foamy sort of stuff to get to explain whatever you want to explain. I mentioned that there are many different forms of photons, and photons are electromagnetic radiation with an electric field, saying on a magnetic field perpendicular to it, and the whole lot travels in the speed of light in the third dimension. There are many, many variations of that. So that that's fine for energy radiation. But how about matter particles? Well, matter particles are nothing more than photons of the appropriate wavelength making uh appropriate energy making two revolutions per wavelength. And when they do that, what holds what allows them to do that is that they rotate around the magnetic field. And suddenly, instead of in a linear photon, magnetic fields are open. When they rotate around the magnetic field, then the magnetic field of a particle is closed. And a closed magnetic field is much more stable than an open magnetic field, and that's why most of the universe, for example, when uh less about, I think the best estimate I've seen, one percent is radiation, the other 99% is photons struggling in circles, making two revolutions per wavelength. And it's for that that gives particles all their properties. Now, I may say this is a bit hairy-fairy, but it's been known for a long, long time that you get a particle and an antiparticle, you put them together, bing, two photons. At the same time, you can get a photon and goes and hit the target, bang, a particle and an antiparticle. Now that shows a relationship between the two that somehow lots of people missed. But what's the simplest relationship you can have? The simplest relationship is that a particle is a photon making two revolutions in one direction, an antiparticle is the same particle making two revolutions in the other direction. Put them together, they unlock. Because they have mass, they have this thing called angular momentum, which is a great Newtonian property. But because mathematicians sort of didn't know what an electron was, they called it a point particle. You can't have angular momentum with a point particle, so they call it spin and they wave all sorts of different things to make it seem as if they know what they're talking about. It's really just angular momentum. And that's the relationship between mass and energy. Energy is the photon zipping along at the speed of light. Mass is the same photon making two revolutions per wavelength. That's how they can interchange so easily. And that property gives particles all of their properties, including mass. And one of the things that Einstein did work out in 1905, those little what they called uh packets of radio of electromagnetic energy, he did work out that they carried momentum or carried inertia, they had momentum, they had mass. I don't know why people want to prove Einstein wrong. Photons have mass. Now I think the reason for this is that they think oh, Einstein's special relativity corrections, anything traveling at the speed of light, will have an infinite mass. The special relativity corrections only apply to photons which are spiraling. And that's just as um the reason for that is about as complicated as uh post Thagoras' theorem. And what he was at 300 BC or something like that, not difficult. And so photons themselves always travel at the speed of light. And so the rotating photons, photons that are rotating, are rotating also at the same speed of light. Well, that's one old hell of a gyroscope. And that is what gives particles a spin, that's why E equals mc squared, and it's all straightforward. There you go. Really? Well. So if we go back a little bit there where you're saying there's no wave nature, what about the double-slit experiment and other sorts of experiments that seem to show wave properties of particles other than photons? Particles um De Royal worked out in 1925 that if if photons, if um photons behave like particles, and particles to behave like photons, I agree with him, it's completely it's completely true. The actual nature of the rotating photon generates the de Broilie wavelength, and it has all the right properties. For me, and to me, Einstein's special and general relativity theories are relatively simple, so it may I may be talking a little bit out of line here. But the deuil wavelength is automatically generated by the particle as it moves. So it's not something that they hypothesize and don't know what occurs. They they hypothesized it, they measured it, but they don't know how it occurs. Well, yeah, it's quite it's fairly straightforward, but not at uh not not not at this level. What are the implications for this difference in understanding? So are there predictions that you would make that are different to the ones that people following the standard model would make? Oh, not the numbers of them, yeah. So probably the electron tunneling. Where electrons hit a barrier. That's got a very simple mechanical analog. I mean, the electrons are held in uh what you call a very taut field. Now, if you've got something coming up, you've got everything in a tight situation, you come something up banging it at this end, you can do it with billiard balls that'll transport through, and another one will knock out. So, what they call tunneling under this model, but in reality, what they call tunneling is just really a momentum exchange. So that's a little bit like one of those Newton cradles. Where you've got the balls on all attached by a string or a chain to a fulcrum over the top, and one will hit the other one and transfer the momentum to the other one without actually transferring itself. Yeah, you don't get electrons, you know, they have they have wave properties, but yes, but you won't get an electron uh tunneling the wave, the wave is in a very fixed position with respect to the uh electron. It's equal on either side of it. If their tunneling theory were correct, then the lower the energy of the electron, the longer its wavelength, therefore the easier it would be to tunnel. However, in the energy transfer one, the higher the energy, the greater probability it'll knock another electron out the other side. Or it's a simple experiment to do. Just increase the energy of uh an electron coming up to a barrier and see which ones go come out the other end first. Is anyone set up to do that? Oh, anyone could set up to do it. Well, a lot of laboratories could do it. And the so-called tunneling effect is what they use in all of the microelectronics systems. And they wouldn't, it wouldn't, it'd be a very, very simple exercise to carry that out. They may well have done it, and the mathematicians have turned around and added another factor. Yeah, it's a standard thing they do when they don't get the right answer, just add another factor. I can't do that. It's physical reality is physical reality. End of story. I guess that's something to look up and see if someone's done those experiments and and what they did with the results. I think there is I think I'm sure it has been done, and the result is that the higher the energy of the electron, the greater the probability of it emerging on the other side of the barrier. And on the very much bigger scale, are there differences in the way the universe looks for astronomy? Yeah, not as far as astronomy is concerned. What the astronomers see is what there is. No question about it. They're great, they're brilliant, as the astronomers, and most of the experimentalists are they're doing an exceedingly good job. The problem becomes in interpreting what they've seen. And when it comes to the whole universe, for example, it's all based on Einstein's theory of gravity. Well, it should be, but it's more advanced than Newton's inverse square, but for most practical purposes, uh Newton's inverse square works quite well. The two situations where it doesn't work, when the mass is so large, like the mass of the sun or the mass of the center of uh Sagittarius A with the planet or star S2 going around it. That's one situation. The reason why a planet uh or Mercury's orbit precesses in its direction of travel is simply that gravity, when mass is strong enough, gravity actually becomes weaker than inverse square. And that's one of the things you get when you solve Einstein's gravity theory accurately. It becomes weaker than inverse square. Now, when it's weak, if it's weaker than inverse square, Mercury travels a little bit closer to the Sun and is attracted by a slightly stronger force. So it'll arrive back at its perihelion point a little later, and it it'll um process in its direction of travel. And Newton pointed that out in 1687. So I don't know why they didn't sort of work it out correctly. But gravity is weaker than inverse square, is the solution to Einstein's gravity. The other thing is that when gravity is an infinite steady state universe under Newton's theory of gravity, inverse square, will collapse. The reason being that the relative to the universe density mass increases as r cubed, gravity decreases as r squared, so eventually you get to the stage where gravity just uh dominates mass and it collapses. But if gravity is weaker than inverse square, and I just tried to show you that Mercury is precessing orbit because the sun's gravity is weaker than inverse square, well, that applies to all gravity. There's nothing special about our sun, except that it's keeping all us alive on this. When you have an infinite steady-state universe, if gravity is weaker than inverse squares, its effect gets relatively weaker over long distances. And I'm talking typically uh 10 billion light years or something like that, maybe more. But that means an infinite steady-state universe won't collapse. That's a huge, huge difference. That's the biggest thing, mind you, what difference does it make to us here on Earth if uh if Bang's web has seen galaxies, fully formed galaxies 20 billion light years away, doesn't make a scrap of difference to us. But as far as understanding how the universe works, that mistake, and the simple the simple mistake that they the um all mathematicians were uh made, Einstein introduced approximations. He couldn't solve the gravity exactly himself. I have no problem solving his uh his gravity exactly. But he he uh introduced the approximation that one over one plus x approximately equals one minus x. You know, when x is ten to the minus seven or which is or ten to the minus eight, that's a good approximation. I mean you you just read his paper, he says so. And you read the mathematics, you don't even you could read the German version, look at the mathematics, and he says so, and you just work it out, and that was the difference. So, all of their exact solutions to Einstein's gravity, they took where he used the approximation, he derived the figure from one plus one over x, the equivalent of that, and then he rather than do that, he equated it to one minus x, which is which is true. You know, one plus one millionth is nine hundred and ninety millionth. Why they did it, I have no idea. Mind you, it'd be interesting to try and find out why. Uh I think it's if a mathematician of repute says one thing, and I I I will agree that uh on my first readings of Einstein's relativity theories, you think, oh my god, really? Could he understand that then? Then you get in and you start. It's not that difficult. And I think most of them had a solution. You know, somebody came up with a solution to Einstein's group, and everybody just followed it. And nobody, and this is the big thing that I always stress to everybody, don't take somebody's word for it. Go back and check the original yourself. I've seen a few times where people have just made terrible, terrible mistakes. But this would probably be the biggest one in the whole field of cosmology, sorry. Astronomy? You guys, great. Thanks, Uncle Sam, for providing us with all this information. That was part one of my interview with Dr. Viv Robinson. You heard Viv say that matter is made of photons moving in circles. Physicists took Einstein's approximations as gospel instead of using the exact solutions available with lather mathematics. Gravity changes to be weaker over distances, and the universe isn't expanding. Listen next week for part two. If you have any questions for Dr. Robinson, he'd love to answer them on the show. So send your questions to science at diffusionradio.com. If you're in Darlinghurst this Wednesday night, the 5th of July, I will be part of the lineup of scientists speaking at Future Science Talks at the East Village. Go to www.futurescience talks.com.au to grab a ticket and come up and say hello. And if you can't make it Wednesday night, I'll keep you posted on some future talks I'll be giving. And that's all from us this week on Diffusion. Are you a scientist, artist, biohacker, or maker who'd like to be interviewed about your work? Would your company like to sponsor diffusion? Send your contributions, opinions, helpful suggestions and donations to science at diffusionradio.com. That's science at diffusionradio.com. Please subscribe to the Diffusion Science Radio channel on youtube.com slash C slash Diffusion Radio and rate the show on iTunes and tell your friends. Follow me on Twitter at IanWorf. The news music was Rhinos Theme by Kevin McLeod of Incompitech.com. I produce diffusion, which is broadcast around Australia, to 28 stations on the community radio network, including Radio Blue Mountains 89.1 FM in New South Wales, 8CCC in Alice Springs and Tennant Creek, 2 MVR in Nambucker Valley, 3 MVR in the Malleigh Border Districts of Victoria and South Australia, City Park Radio 7LTN in Launcest and Tasmania, and 2XFM in Canberra. Diffusion is narrowcast on Indigo FM88 in Northeast Victoria. Diffusion is syndicated globally on astronomy.fm. Subscribe to the podcast on the diffusion website www.diffusionradio.com. That's www.diffusionradio.com and check the website for links, photos, and videos about this week's show. If you enjoyed the show, you can explore more than a thousand previous episodes archived on diffusionradio.com where the shows are labelled by keywords so you can focus in on the stories you want to hear. Make a donation through PayPal.me slash Ian Worf. Or join my patrons at patreon.com slash Diffusion Radio. I'm Ian Worf. Join us inside your audio device of choice for more science wondering next week on Diffusion Science Radio. Science is fun. It helps you to learn, to know, and to appreciate. When you study science, you make fun feel. diff --git a/tests/parakeet-expected-gb1-output.txt b/tests/parakeet-expected-gb1-output.txt new file mode 100644 index 00000000000..312ed1ce048 --- /dev/null +++ b/tests/parakeet-expected-gb1-output.txt @@ -0,0 +1 @@ +My fellow Americans, this day has brought terrible news and great sadness to our country. At nine o'clock this morning, mission control in Houston lost contact with our space shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. On board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Kulpna Shavla, and Ilan Ramon, a colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. In an age when space flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the earth. These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. Yet farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens. Who created all these? He who brings out the starry hosts one by one and calls them each by name. Because of his great power and mighty strength, not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to Earth. Yet we can pray that all are safely home. May God bless the grieving families and make out may God continue to bless America. diff --git a/tests/parakeet-expected-jfk-output.txt b/tests/parakeet-expected-jfk-output.txt new file mode 100644 index 00000000000..ece35697ae8 --- /dev/null +++ b/tests/parakeet-expected-jfk-output.txt @@ -0,0 +1 @@ +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/tests/parakeet-verification.h b/tests/parakeet-verification.h new file mode 100644 index 00000000000..0e95610ba26 --- /dev/null +++ b/tests/parakeet-verification.h @@ -0,0 +1,110 @@ +#pragma once + +#include <algorithm> +#include <cassert> +#include <cctype> +#include <cstdio> +#include <fstream> +#include <iterator> +#include <string> +#include <vector> + +#ifndef TRANSCRIPTION_SIMILARITY_THRESHOLD +#define TRANSCRIPTION_SIMILARITY_THRESHOLD 1.0 +#endif + +static std::string read_expected_transcription(const char * path) { + std::ifstream fin(path); + assert(fin.is_open()); + + std::string text( + (std::istreambuf_iterator<char>(fin)), + std::istreambuf_iterator<char>()); + + while (!text.empty() && (text.back() == '\n' || text.back() == '\r')) { + text.pop_back(); + } + + return text; +} + +static std::vector<std::string> transcription_words(const std::string & text) { + std::vector<std::string> words; + std::string word; + + for (unsigned char ch : text) { + if (std::isalnum(ch)) { + word.push_back((char) std::tolower(ch)); + } else if (!word.empty()) { + words.push_back(word); + word.clear(); + } + } + + if (!word.empty()) { + words.push_back(word); + } + + return words; +} + +static double transcription_lcs_similarity(const std::string & expected, const std::string & actual) { + const std::vector<std::string> expected_words = transcription_words(expected); + const std::vector<std::string> actual_words = transcription_words(actual); + + if (expected_words.empty() && actual_words.empty()) { + return 1.0; + } + + if (expected_words.empty() || actual_words.empty()) { + return 0.0; + } + + std::vector<int> prev(actual_words.size() + 1, 0); + std::vector<int> cur (actual_words.size() + 1, 0); + + for (size_t i = 0; i < expected_words.size(); ++i) { + std::fill(cur.begin(), cur.end(), 0); + + for (size_t j = 0; j < actual_words.size(); ++j) { + if (expected_words[i] == actual_words[j]) { + cur[j + 1] = prev[j] + 1; + } else { + cur[j + 1] = std::max(prev[j + 1], cur[j]); + } + } + + prev.swap(cur); + } + + const int lcs = prev[actual_words.size()]; + return (2.0 * lcs) / (expected_words.size() + actual_words.size()); +} + +static bool verify_transcription(const std::string & expected, const std::string & actual) { + const double threshold = TRANSCRIPTION_SIMILARITY_THRESHOLD; + + if (threshold >= 1.0) { + if (actual == expected) { + return true; + } + + fprintf(stderr, "\n\n"); + fprintf(stderr, "[Failed] Transcript mismatched\n"); + fprintf(stderr, "expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "actual:\n%s\n", actual.c_str()); + return false; + } + + const double similarity = transcription_lcs_similarity(expected, actual); + printf("\nTranscript similarity: %.6f (threshold %.6f)\n", similarity, threshold); + + if (similarity >= threshold) { + return true; + } + + fprintf(stderr, "\n\nTranscript similarity below threshold: %.6f < %.6f\n", similarity, threshold); + fprintf(stderr, "Expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "Actual:\n%s\n", actual.c_str()); + return false; +} diff --git a/tests/run-tests.sh b/tests/run-tests.sh index ad2b8d3ec09..bc28314a704 100755 --- a/tests/run-tests.sh +++ b/tests/run-tests.sh @@ -21,13 +21,21 @@ cd `dirname $0` # Whisper models models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" "large-v3-turbo" ) +# Parakeet model variants +parakeet_models=( "f16" "f32" "q2_k" "q4_0" "q4_k" "q8_0" ) + # list available models function list_models { printf "\n" - printf " Available models:" + printf " Available whisper models:" for model in "${models[@]}"; do printf " $model" done + printf "\n" + printf " Available parakeet models:" + for model in "${parakeet_models[@]}"; do + printf " parakeet-$model" + done printf "\n\n" } @@ -39,15 +47,37 @@ if [ $# -eq 0 ]; then fi model=$1 -main="../build/bin/whisper-cli" threads="" if [ $# -eq 2 ]; then threads="-t $2" fi -if [ ! -f ../models/ggml-$model.bin ]; then - printf "Model $model not found. Aborting\n" +# Detect parakeet model (prefix "parakeet-" or a bare variant like "f32") +is_parakeet=0 +parakeet_variant="" +if [[ $model == parakeet-* ]]; then + is_parakeet=1 + parakeet_variant="${model#parakeet-}" +fi +for v in "${parakeet_models[@]}"; do + if [[ $model == "$v" ]]; then + is_parakeet=1 + parakeet_variant="$v" + break + fi +done + +if [ $is_parakeet -eq 1 ]; then + main="../build/bin/parakeet-cli" + model_path="../models/ggml-parakeet-tdt-0.6b-v3-${parakeet_variant}.bin" +else + main="../build/bin/whisper-cli" + model_path="../models/ggml-${model}.bin" +fi + +if [ ! -f $model_path ]; then + printf "Model $model not found ($model_path). Aborting\n" list_models exit 1 fi @@ -110,7 +140,11 @@ function run_lang() { fi fi - $main -m ../models/ggml-$model.bin $threads -f $fname_dst -l $lang -otxt 2> /dev/null + if [ $is_parakeet -eq 1 ]; then + $main -m $model_path $threads -f $fname_dst -otxt 2> /dev/null + else + $main -m $model_path $threads -f $fname_dst -l $lang -otxt 2> /dev/null + fi git diff --no-index --word-diff=color --word-diff-regex=. $lang-$i-ref.txt $fname_dst.txt @@ -120,7 +154,7 @@ function run_lang() { run_lang "en" "${urls_en[@]}" -if [[ $model != *.en* ]]; then +if [ $is_parakeet -eq 0 ] && [[ $model != *.en* ]]; then run_lang "es" "${urls_es[@]}" run_lang "it" "${urls_it[@]}" run_lang "pt" "${urls_pt[@]}" diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp new file mode 100644 index 00000000000..22ac4c20e31 --- /dev/null +++ b/tests/test-parakeet-full.cpp @@ -0,0 +1,101 @@ +#include "parakeet.h" +#include "common-whisper.h" +#include "parakeet-verification.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +struct test_state { + bool is_first = true; + std::string transcript; +}; + +void progress_callback(parakeet_context * ctx, parakeet_state * state, int progress, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; +} + +bool encoder_begin_callback(parakeet_context * ctx, parakeet_state * state, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return true; +} + +bool abort_callback(void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return false; // just continue without aborting. +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + test_state * tstate = static_cast<test_state *>(user_data); + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, tstate->is_first, text_buf, sizeof(text_buf)); + + printf("%s", text_buf); + fflush(stdout); + + tstate->transcript += text_buf; + tstate->is_first = false; +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); // no stereo vector + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + test_state tstate; + params.new_token_callback = token_callback; + params.new_token_callback_user_data = &tstate; + bool progress_callback_called = false; + params.progress_callback = progress_callback; + params.progress_callback_user_data = &progress_callback_called; + bool encoder_begin_callback_called = false; + params.encoder_begin_callback = encoder_begin_callback; + params.encoder_begin_callback_user_data = &encoder_begin_callback_called; + bool abort_callback_called = false; + params.abort_callback = abort_callback; + params.abort_callback_user_data = &abort_callback_called; + + int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + assert(progress_callback_called); + assert(encoder_begin_callback_called); + assert(abort_callback_called); + + const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH); + const bool transcript_matches = verify_transcription(expected, tstate.transcript); + + parakeet_free(pctx); + + if (!transcript_matches) { + return 1; + } + + printf("\nTest passed: parakeet_full succeeded!\n"); + return 0; +} diff --git a/tests/test-parakeet.cpp b/tests/test-parakeet.cpp new file mode 100644 index 00000000000..83237c600ac --- /dev/null +++ b/tests/test-parakeet.cpp @@ -0,0 +1,99 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + // Load the sample audio file + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + params.new_segment_callback = segment_callback; + params.new_segment_callback_user_data = nullptr; + parakeet_state * state = parakeet_init_state(pctx); + + int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\nTest passed: Parakeet model loaded and freed successfully\n"); + return 0; +} From 0d14756929dc9f21ddccf6102bb783397b7a8f1b Mon Sep 17 00:00:00 2001 From: KITAITI Makoto <KitaitiMakoto@gmail.com> Date: Wed, 17 Jun 2026 13:42:09 +0900 Subject: [PATCH 285/289] ruby : add support for Parakeet (#3885) * Add Whisper::Parakeet::Params * Add tests for Parakeet::Params * Remove unused variabel * Add callbacks to Parakeet::Params * Group callback and user_data params * Undefine local macros * Define GetParakeetParams * Remove unused variable * Use ITERATE_CALLBACK_PARAMS * Use ITERATE_CALLBACK_PARAMS instead of ITERATE_USER_DATA_PARAMS * Fix memsize * Remove unnecessary macros * Simplify params registration * Define Parakeet * Add hook methods to Parakeet::Params * Fix typo * Check callback container in GetParakeetParams * Reduce if * Free parakeet_full_params * Implement Parakeet::Context#initialize * Add TestParakeetContext * Add Parakeet::Segment * Prevent double-free * Add Parakeet::Context#transcribe * Add Parakeet::Context#each_segment * Define Parakeet::Segment attributes * Define Parakeet::Segment#deconstruct_keys * Add tests for Parakeet::Segment#deconstruct_keys * Run Parakeet::Context#transcribe without GVL * Make it to abort for Parakeet * Add Parakeet.log_set * Define Parakeet::Token * Define Parakeet::Segment#each_token * Implement some hooks of Parakeet::Params * Convert int to VALUE * Implement hooks for Parakeet * Implement Parakeet::Context#full * Add tests for Parakeet::Context#full * Add Parakeet to RBS * Fix ruby_whisper_parakeet_params_free * Free ruby_whisper_parakeet_context * Add tests for hooks * Add Parakeet section to README * Add more attributes of Parakeet::Context * Add tests for Parakeet::Context's attributes * Update RBS * Register parakeet-tdt-0.6b-v3 * Narrow scope of log constants * Extract activate and deactivate of log_queue * Make start_log_callback_thread private * Don't call start_log_callback_thread unncecessarilly * Early return from log_queue_enqueue when not active * Gropu log_queue members * is_active -> is_open * Fix English * Share parakeet full body function * ruby_whisper_parakeet_abort_callback_user_data -> ruby_whisper_abort_callback_user_data * NULL check for callback containers * Fix Parakeet.log_set * Omit Parakeet tests on CI * Extract Whisper::LogSettable * Join log callback thread in a log queue function * Revert Join log callback thread in a log queue function * Extract output methods to modules * Move Parakeet init functions into init_parakeet() * Add output methods to Parakeet classes * Add Parakeet's output methods to RBS * Use Whisper::Output in RBS * Add LogSettable to RBS * Fix module Token -> class Token * Add Parakeet::Model * Add test for Parakeet::Model * Add Parakeet::Model to RBS * Move position of Parakeet::Model in RBS * Parakeet -> TestBase::Parakeet * Add Parakeet::Context#model in RBS * Add Whisper::Output * Fix nil check * Define ruby_whisper_parakeet_model_memsize * Fix order of declaration in ruby_whisper_parakeet_model_get_xxx * Define Parakeet.system_info_str * Add test for Parakeet.system_info_str * Add signature of Parakeet.system_info_str * Define Parakeet::VERSION * Add test for Parakeet::VERSION * Add signature of Parakeet::VERSION * Add Parakeet::Context::Params * Make Parakeet::Context.new accept Context::Params * Add test for Parakeet::Context.new with Context::Params * Update RBS * Remove params from Parakeet::Params which are moved from whisper_parakeet_full_params * Remove tests for removed params * Make Parakeet tests follow original behavior changes * Add Parakeet model shortcuts * Alloc token data in factory instead of alloc func * Fix variable name * Update RBS * Refactor log settable module * Use log settable for Whisper * Address deadlock * Make test follow change of log queue implementation * Refactor to make abort callback use the same way to parakeet's way * Remove redundant structs * Fix test name * Fix README * Add missing parallel transcription * Fix test for parakeet info * Remove removed params * Wait for logs dequeued * Fix instance variable name * Load etc feature * Remove unnecessary comment * Remove unnecessary thread safety check * Remove outdated comment * Skip downloading model if cache exists * Change Hugging Face URI for Parakeet models * Bump required Ruby version to 3.3 * Fix English --- .github/workflows/bindings-ruby.yml | 2 +- bindings/ruby/README.md | 31 + bindings/ruby/Rakefile | 17 +- bindings/ruby/ext/ruby_whisper.c | 116 ++-- bindings/ruby/ext/ruby_whisper.h | 135 ++++- bindings/ruby/ext/ruby_whisper_context.c | 51 +- bindings/ruby/ext/ruby_whisper_log_queue.c | 180 ++++++ bindings/ruby/ext/ruby_whisper_log_settable.h | 47 ++ bindings/ruby/ext/ruby_whisper_parakeet.c | 49 ++ .../ruby/ext/ruby_whisper_parakeet_context.c | 304 ++++++++++ .../ruby_whisper_parakeet_context_params.c | 117 ++++ .../ruby/ext/ruby_whisper_parakeet_model.c | 84 +++ .../ruby/ext/ruby_whisper_parakeet_params.c | 548 ++++++++++++++++++ .../ruby/ext/ruby_whisper_parakeet_segment.c | 157 +++++ .../ruby/ext/ruby_whisper_parakeet_token.c | 188 ++++++ .../ext/ruby_whisper_parakeet_transcribe.cpp | 58 ++ bindings/ruby/ext/ruby_whisper_params.c | 117 ++-- bindings/ruby/ext/ruby_whisper_segment.c | 12 +- bindings/ruby/ext/ruby_whisper_transcribe.cpp | 62 +- bindings/ruby/lib/whisper/context.rb | 15 - bindings/ruby/lib/whisper/log_settable.rb | 36 ++ bindings/ruby/lib/whisper/model/uri.rb | 14 +- bindings/ruby/lib/whisper/output.rb | 74 +++ bindings/ruby/lib/whisper/segment.rb | 58 -- bindings/ruby/sig/whisper.rbs | 369 +++++++++++- bindings/ruby/test/helper.rb | 2 + bindings/ruby/test/test_callback.rb | 1 + bindings/ruby/test/test_parakeet.rb | 28 + bindings/ruby/test/test_parakeet_callback.rb | 107 ++++ bindings/ruby/test/test_parakeet_context.rb | 116 ++++ .../ruby/test/test_parakeet_context_params.rb | 24 + bindings/ruby/test/test_parakeet_model.rb | 21 + bindings/ruby/test/test_parakeet_params.rb | 78 +++ bindings/ruby/test/test_parakeet_segment.rb | 42 ++ bindings/ruby/test/test_parakeet_token.rb | 73 +++ bindings/ruby/test/test_vad_segment.rb | 2 +- bindings/ruby/test/test_whisper.rb | 1 + bindings/ruby/whispercpp.gemspec | 2 +- 38 files changed, 3005 insertions(+), 333 deletions(-) create mode 100644 bindings/ruby/ext/ruby_whisper_log_queue.c create mode 100644 bindings/ruby/ext/ruby_whisper_log_settable.h create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_context.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_context_params.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_model.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_params.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_segment.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_token.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp delete mode 100644 bindings/ruby/lib/whisper/context.rb create mode 100644 bindings/ruby/lib/whisper/log_settable.rb create mode 100644 bindings/ruby/lib/whisper/output.rb delete mode 100644 bindings/ruby/lib/whisper/segment.rb create mode 100644 bindings/ruby/test/test_parakeet.rb create mode 100644 bindings/ruby/test/test_parakeet_callback.rb create mode 100644 bindings/ruby/test/test_parakeet_context.rb create mode 100644 bindings/ruby/test/test_parakeet_context_params.rb create mode 100644 bindings/ruby/test/test_parakeet_model.rb create mode 100644 bindings/ruby/test/test_parakeet_params.rb create mode 100644 bindings/ruby/test/test_parakeet_segment.rb create mode 100644 bindings/ruby/test/test_parakeet_token.rb diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 80a243e4c98..8cdb7a810f7 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -27,6 +27,6 @@ jobs: steps: - uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0 with: - ruby-version: '3.2' + ruby-version: '3.3' - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: rake test diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 07b81830c58..7f6b7d92c09 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -396,6 +396,37 @@ whisper .full(Whisper::Params.new, samples) ``` +### Parakeet ### + +whispercpp gem now supports NVIDIA's ASR model Parakeet. + +If you want to use Parakeet instead of Whisper, the API should feel familiar. +In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way. + +```ruby +require "whisper" + +# It's useful to assign Whisper::Parakeet to top-level Parakeet constant unless you use Parakeet gem. +Parakeet = Whisper::Parakeet + +parakeet = Parakeet::Context.new("path/to/model") + +params = Parakeet::Params.new( + no_context: true +) + +parakeet + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}" + end +``` + +The main differences are: + +* Namespace is `Whisper::Parakeet`. +* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks. + Custom context params --------------------- diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 7b521b3bdfa..2327651a06a 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -84,6 +84,21 @@ else end end +TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin" +TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin")) +TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d") +directory TEST_PARAKEET_MODEL_DIR +if File.exist? TEST_PARAKEET_MODEL_SRC + file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t| + symlink t.source, t.name + end +else + require "open-uri" + file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t| + File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read + end +end + TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}" file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| chdir "test/jfk_reader" do @@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| end CLEAN.include TEST_MEMORY_VIEW -task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO] +task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL] diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 56fceb1c894..7941b1a99dd 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,19 +1,29 @@ #include "ruby_whisper.h" VALUE mWhisper; +VALUE mLogSettable; VALUE mVAD; +VALUE mParakeet; VALUE cContext; VALUE cParams; VALUE cVADContext; VALUE cVADParams; VALUE cVADSegments; VALUE cVADSegment; +VALUE cParakeetContext; +VALUE cParakeetContextParams; +VALUE cParakeetParams; +VALUE cParakeetSegment; +VALUE cParakeetModel; VALUE eError; VALUE cSegment; VALUE cToken; VALUE cModel; +VALUE mOutputContext; +VALUE mOutputSegment; + ID id_to_s; ID id_call; ID id___method__; @@ -27,9 +37,11 @@ ID id_pre_converted_models; ID id_coreml_compiled_models; ID id_cache; ID id_n_processors; - -static bool is_log_callback_finalized = false; -static bool is_ruby_log_callback_present = false; +ID id_extended; +ID id_start_log_callback_thread; +ID id_log_callback_thread; +ID id_alive_p; +ID id_join; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); @@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD); extern void init_ruby_whisper_vad_context(VALUE *mVAD); extern void init_ruby_whisper_vad_segment(VALUE *mVAD); extern void init_ruby_whisper_vad_segments(VALUE *mVAD); +extern void init_ruby_whisper_parakeet(VALUE *mWhisper); extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); +static ruby_whisper_log_queue whisper_log_queue; + +LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set) + /* * call-seq: * lang_max_id -> Integer @@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) { return rb_str_new2(whisper_print_system_info()); } -static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { - is_log_callback_finalized = true; - return Qnil; -} - -typedef struct { - int level; - const char * buffer; -} call_log_callbacks_args; - -static void* -call_log_callbacks(void *v_args) { - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (NIL_P(log_callback)) { - return NULL; - } - - call_log_callbacks_args *args = (call_log_callbacks_args *)v_args; - VALUE user_data = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data); - - return NULL; -} - -static void -ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { - if (is_log_callback_finalized) { - return; - } - if (!is_ruby_log_callback_present) { - return; - } - - call_log_callbacks_args args = { - level, - buffer, - }; - if (ruby_thread_has_gvl_p()) { - call_log_callbacks((void *)&args); - } else { - rb_thread_call_with_gvl(call_log_callbacks, (void *)&args); - } -} - -/* - * call-seq: - * log_set ->(level, buffer, user_data) { ... }, user_data -> nil - */ -static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { - VALUE old_callback = rb_iv_get(self, "log_callback"); - if (!NIL_P(old_callback)) { - rb_undefine_finalizer(old_callback); - } - - rb_iv_set(self, "log_callback", log_callback); - rb_iv_set(self, "user_data", user_data); - - if (!NIL_P(log_callback)) { - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); - } - - if (NIL_P(log_callback)) { - whisper_log_set(NULL, NULL); - is_ruby_log_callback_present = false; - } else { - whisper_log_set(ruby_whisper_log_callback, NULL); - is_ruby_log_callback_present = true; - } - - return Qnil; -} - void Init_whisper() { id_to_s = rb_intern("to_s"); id_call = rb_intern("call"); @@ -189,9 +133,19 @@ void Init_whisper() { id_coreml_compiled_models = rb_intern("coreml_compiled_models"); id_cache = rb_intern("cache"); id_n_processors = rb_intern("n_processors"); + id_extended = rb_intern("extended"); + id_start_log_callback_thread = rb_intern("start_log_callback_thread"); + id_log_callback_thread = rb_intern("@log_callback_thread"); + id_alive_p = rb_intern("alive?"); + id_join = rb_intern("join"); mWhisper = rb_define_module("Whisper"); + rb_require("whisper/log_settable"); + mLogSettable = rb_path2class("Whisper::LogSettable"); mVAD = rb_define_module_under(mWhisper, "VAD"); + rb_require("whisper/output"); + mOutputContext = rb_path2class("Whisper::Output::Context"); + mOutputSegment = rb_path2class("Whisper::Output::Segment"); rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version())); rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); @@ -222,8 +176,8 @@ void Init_whisper() { rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0); - rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); - rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); + + LOG_SETTABLE_INIT(whisper_log_queue, mWhisper) cContext = init_ruby_whisper_context(&mWhisper); init_ruby_whisper_context_params(&cContext); @@ -236,8 +190,10 @@ void Init_whisper() { init_ruby_whisper_vad_segment(&mVAD); init_ruby_whisper_vad_segments(&mVAD); init_ruby_whisper_vad_context(&mVAD); + init_ruby_whisper_parakeet(&mWhisper); - rb_require("whisper/context"); - rb_require("whisper/segment"); rb_require("whisper/model/uri"); + + rb_include_module(cContext, mOutputContext); + rb_include_module(cSegment, mOutputSegment); } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index ba4d8b6fbcc..10e90674953 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -5,8 +5,12 @@ #include <ruby/version.h> #include <ruby/util.h> #include <ruby/thread.h> +#include <ruby/thread_native.h> +#include <ruby/atomic.h> #include <ruby/memory_view.h> #include "whisper.h" +#include "parakeet.h" +#include "ruby_whisper_log_settable.h" #if RUBY_API_VERSION_MAJOR < 4 // Exists but not declared as public API @@ -20,13 +24,28 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; -typedef struct { - VALUE *context; - VALUE user_data; - VALUE callback; - VALUE callbacks; - bool is_interrupted; -} ruby_whisper_abort_callback_container; +typedef struct ruby_whisper_abort_callback_user_data { + volatile rb_atomic_t is_interrupted; + ruby_whisper_callback_container *callback_container; +} ruby_whisper_abort_callback_user_data; + +typedef struct ruby_whisper_log { + enum ggml_log_level level; + char *text; + size_t length; + size_t capacity; +} ruby_whisper_log; + +typedef struct ruby_whisper_log_queue { + rb_nativethread_lock_t lock; + rb_nativethread_cond_t cond; + bool is_open; + + size_t head; + size_t tail; + size_t size; + ruby_whisper_log *logs; +} ruby_whisper_log_queue; typedef struct { struct whisper_context *context; @@ -42,7 +61,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; @@ -84,6 +103,63 @@ typedef struct parsed_samples_t { bool memview_exported; } parsed_samples_t; +typedef struct { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} ruby_whisper_full_args; + +typedef struct ruby_whisper_full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} ruby_whisper_full_parallel_args; + +typedef struct { + struct parakeet_full_params params; + ruby_whisper_callback_container *new_segment_callback_container; + ruby_whisper_callback_container *new_token_callback_container; + ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *encoder_begin_callback_container; + ruby_whisper_callback_container *abort_callback_container; +} ruby_whisper_parakeet_params; + +typedef struct { + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_params; + +typedef struct { + struct parakeet_context *context; +} ruby_whisper_parakeet_context; + +typedef struct { + VALUE context; + int index; +} ruby_whisper_parakeet_segment; + +typedef struct { + parakeet_token_data *token_data; + VALUE text; +} ruby_whisper_parakeet_token; + +typedef struct { + VALUE context; +} ruby_whisper_parakeet_model; + +extern ID id_extended; +extern ID id_log_callback_thread; +extern ID id_start_log_callback_thread; +extern ID id_alive_p; +extern ID id_join; +extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text); +extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue); + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -120,4 +196,47 @@ typedef struct parsed_samples_t { } \ } while (0) +#define GetParakeetContextParams(obj, rwpcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \ +} while (0) + +#define GetParakeetContext(obj, rwpc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \ + if ((rwpc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetParams(obj, rwpp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \ + if (!(rwpp)->new_segment_callback_container || \ + !(rwpp)->new_token_callback_container || \ + !(rwpp)->progress_callback_container || \ + !(rwpp)->encoder_begin_callback_container || \ + !(rwpp)->abort_callback_container) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetSegment(obj, rwps) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \ + if (!(rwps)->context) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetToken(obj, rwpt) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \ + if (!(rwpt)->token_data) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetModel(obj, rwpm) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \ + if (NIL_P((rwpm)->context)) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + #endif diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 26058fc07e6..9e5fc33e726 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data); ID transcribe_option_names[1]; @@ -38,21 +38,6 @@ typedef struct fill_samples_args { int n_samples; } fill_samples_args; -typedef struct full_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; -} full_args; - -typedef struct full_parallel_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; - int n_processors; -} full_parallel_args; - typedef struct full_without_gvl_args { struct whisper_context *context; struct whisper_full_params *params; @@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args { } full_parallel_without_gvl_args; typedef struct full_ubf_args { - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_abort_callback_user_data *abort_callback_user_data; } full_ubf_args; static void @@ -379,7 +364,7 @@ fill_samples(VALUE rb_args) return Qnil; } -struct parsed_samples_t +parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples) { bool memview_available = rb_memory_view_available_p(*samples); @@ -480,20 +465,24 @@ full_ubf(void *rb_args) { full_ubf_args *args = (full_ubf_args *)rb_args; - args->abort_callback_container->is_interrupted = true; + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); } -static VALUE +VALUE full_body(VALUE rb_args) { - full_args *args = (full_args *)rb_args; + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, 1); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, 1, &abort_callback_user_data); struct full_without_gvl_args full_without_gvl_args = { rw->context, @@ -503,7 +492,7 @@ full_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_without_gvl_args.result); @@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) VALUE n_samples = argc == 2 ? Qnil : argv[2]; struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - full_args args = { + ruby_whisper_full_args args = { &self, &argv[0], parsed.samples, @@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args) return NULL; } -static VALUE +VALUE full_parallel_body(VALUE rb_args) { - full_parallel_args *args = (full_parallel_args *)rb_args; + ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, args->n_processors); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data); struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { rw->context, @@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_parallel_without_gvl_args.result); @@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) break; } struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - const full_parallel_args args = { + const ruby_whisper_full_parallel_args args = { &self, &argv[0], parsed.samples, diff --git a/bindings/ruby/ext/ruby_whisper_log_queue.c b/bindings/ruby/ext/ruby_whisper_log_queue.c new file mode 100644 index 00000000000..6558a339c6f --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_queue.c @@ -0,0 +1,180 @@ +#include "ruby_whisper.h" + +#define LOG_QUEUE_CAPACITY 256 +#define LOG_DEFAULT_CAPACITY 1024 + +void +ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_initialize(&log_queue->lock); + rb_native_cond_initialize(&log_queue->cond); + log_queue->head = 0; + log_queue->tail = 0; + log_queue->size = 0; + log_queue->is_open = true; + log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY); + for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) { + // we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL + // this doesn't be freed because log queue lives until the end of process + char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY); + if (!slot) { + rb_raise(rb_eRuntimeError, "Could not allocate memory for log text"); + } + ruby_whisper_log log = { + 0, + slot, + 0, + LOG_QUEUE_CAPACITY, + }; + log_queue->logs[i] = log; + } +} + +void +ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = true; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +void +ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = false; + rb_native_cond_broadcast(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static size_t +calc_enough_cap(size_t len) +{ + size_t quot = len / LOG_DEFAULT_CAPACITY; + size_t rem = len % LOG_DEFAULT_CAPACITY; + + return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY; +} + +void +ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + if (!log_queue->is_open) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + + size_t len = strlen(text); + ruby_whisper_log *log = &log_queue->logs[log_queue->head]; + if (len > log->capacity) { + size_t new_cap = calc_enough_cap(len); + // we cannot call Ruby API like REALLOC_N because this function is called without GVL + char *slot = realloc(log->text, new_cap); + if (!slot) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + log->text = slot; + log->capacity = new_cap; + } + // we cannot call Ruby API like MEMCPY because this function is called without GVL + memcpy(log->text, text, sizeof(char) * len); + log->length = len; + log->level = level; + log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY; + bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY; + log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1; + if (is_full) { + log_queue->tail = log_queue->head; + } + + rb_native_cond_signal(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static void* +ruby_whisper_log_queue_wait(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_wait(&log_queue->cond, &log_queue->lock); + rb_nativethread_lock_unlock(&log_queue->lock); + + return NULL; +} + +static void +ruby_whisper_log_queue_wait_ubf(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_broadcast(&log_queue->cond); +} + +typedef struct { + enum ggml_log_level level; + size_t length; + char *text; +} log_snapshot; + +VALUE +ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue) +{ + log_snapshot logs[LOG_QUEUE_CAPACITY]; + + rb_nativethread_lock_lock(&log_queue->lock); + + while (log_queue->size == 0 && log_queue->is_open) { + rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue); + rb_nativethread_lock_lock(&log_queue->lock); + } + + if (log_queue->size == 0 && !log_queue->is_open) { + rb_native_cond_broadcast(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); + return Qnil; + } + + size_t size = log_queue->size; + ruby_whisper_log *log; + size_t i; + for (i = 0; i < size; i++) { + log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY]; + logs[i].level = log->level; + logs[i].length = log->length; + char *text = malloc(log->length); + if (!text) { + logs[i].text = NULL; + continue; + } + logs[i].text = text; + memcpy(logs[i].text, log->text, log->length); + } + log_queue->size = 0; + log_queue->tail = log_queue->head; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); + + VALUE rb_logs = rb_ary_new2(size); + VALUE rb_text; + for (i = 0; i < size; i++) { + if (!logs[i].text) { + continue; + } + rb_text = rb_str_new(logs[i].text, logs[i].length); + free(logs[i].text); + rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text)); + } + + return rb_logs; +} diff --git a/bindings/ruby/ext/ruby_whisper_log_settable.h b/bindings/ruby/ext/ruby_whisper_log_settable.h new file mode 100644 index 00000000000..b98fbac826b --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_settable.h @@ -0,0 +1,47 @@ +#ifndef RUBY_WHISPER_LOG_SETTABLE_H +#define RUBY_WHISPER_LOG_SETTABLE_H + +#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \ + static VALUE \ + ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \ + { \ + return ruby_whisper_log_queue_drain(&log_queue); \ + } \ + static void \ + ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \ + { \ + ruby_whisper_log_queue_enqueue(&log_queue, level, text); \ + } \ + static VALUE \ + ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \ + { \ + rb_iv_set(self, "@log_callback", log_callback); \ + rb_iv_set(self, "@log_callback_user_data", user_data); \ + if (NIL_P(log_callback)) { \ + log_set(NULL, NULL); \ + } else { \ + ruby_whisper_log_queue_open(&log_queue); \ + rb_funcall((mod), id_start_log_callback_thread, 0); \ + log_set(ruby_whisper_##log_queue##_log_callback, NULL); \ + } \ + return Qnil; \ + } \ + static void \ + ruby_whisper_##log_queue##_end_proc(VALUE args) \ + { \ + ruby_whisper_log_queue_close(&log_queue); \ + VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \ + if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \ + rb_funcall(log_callback_thread, id_join, 0); \ + } \ + } + +#define LOG_SETTABLE_INIT(log_queue, mod) \ + ruby_whisper_log_queue_initialize(&log_queue); \ + rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \ + rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \ + rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \ + rb_extend_object(mod, mLogSettable); \ + rb_funcall(mLogSettable, id_extended, 1, mod); + +#endif diff --git a/bindings/ruby/ext/ruby_whisper_parakeet.c b/bindings/ruby/ext/ruby_whisper_parakeet.c new file mode 100644 index 00000000000..d69369401d0 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet.c @@ -0,0 +1,49 @@ +#include "ruby_whisper.h" +#include <stdio.h> +#include <unistd.h> + +extern VALUE mParakeet; +extern VALUE mLogSettable; +extern VALUE cParakeetContext; +extern VALUE cParakeetSegment; +extern VALUE mOutputContext; +extern VALUE mOutputSegment; + +extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet); +extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext); +extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet); + +static ruby_whisper_log_queue parakeet_log_queue; + +LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set) + +static VALUE +ruby_whisper_parakeet_s_system_info_str(VALUE self) +{ + return rb_str_new2(parakeet_print_system_info()); +} + +void +init_ruby_whisper_parakeet(VALUE *mWhisper) +{ + mParakeet = rb_define_module_under(*mWhisper, "Parakeet"); + + rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version())); + + LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet) + + rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0); + + init_ruby_whisper_parakeet_params(&mParakeet); + init_ruby_whisper_parakeet_token(&mParakeet); + init_ruby_whisper_parakeet_segment(&mParakeet); + cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet); + init_ruby_whisper_parakeet_context_params(&cParakeetContext); + init_ruby_whisper_parakeet_model(&mParakeet); + + rb_include_module(cParakeetContext, mOutputContext); + rb_include_module(cParakeetSegment, mOutputSegment); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c new file mode 100644 index 00000000000..b4a2fc5c4b7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -0,0 +1,304 @@ +#include "ruby_whisper.h" + +#define ITERATE_SEGMENT_ATTRS(ITERATOR) \ + ITERATOR(get_segment_t0, LONG) \ + ITERATOR(get_segment_t1, LONG) \ + ITERATOR(get_segment_text, STRING) \ + ITERATOR(n_tokens, INT) + +#define ITERATE_TOKEN_ATTRS(ITERATOR) \ + ITERATOR(get_token_text, STRING) \ + ITERATOR(get_token_id, INT) \ + ITERATOR(get_token_p, FLOAT) + +#define VAL_FROM_LONG(v) LONG2NUM(v) +#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v) +#define VAL_FROM_INT(v) INT2NUM(v) +#define VAL_FROM_FLOAT(v) DBL2NUM(v) +#define READER(type) VAL_FROM_##type + +extern ID id_to_s; +extern ID id___method__; +extern ID id_to_enum; +extern ID id_new; + +extern VALUE cParakeetContext; +extern VALUE eError; + +extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE rb_parsed_args); +extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data); +extern rb_data_type_t ruby_whisper_parakeet_params_type; +extern rb_data_type_t ruby_whisper_parakeet_context_params_type; +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); +extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context); + +static void +ruby_whisper_parakeet_context_free(void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (rwpc->context) { + parakeet_free(rwpc->context); + rwpc->context = NULL; + } + xfree(rwpc); +} + +static size_t +ruby_whisper_parakeet_context_memsize(const void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (!rwpc) { + return 0; + } + size_t size = sizeof(*rwpc); + return size; +} + +const rb_data_type_t ruby_whisper_parakeet_context_type = { + "ruby_whisper_parakeet_context", + {0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_context_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context *rwpc; + + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + rwpc->context = NULL; + + return obj; +} + +typedef struct { + struct parakeet_context **context; + char *model_path; + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_init_args; + +static void* +ruby_whisper_parakeet_context_init_without_gvl(void *args) +{ + ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args; + *init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params); + return NULL; +} + +static VALUE +ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + VALUE model_path; + VALUE context_params; + struct parakeet_context_params params; + + rb_scan_args(argc, argv, "11", &model_path, &context_params); + TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + + model_path = ruby_whisper_normalize_model_path(model_path); + if (!rb_respond_to(model_path, id_to_s)) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context"); + } + if (NIL_P(context_params)) { + params = parakeet_context_default_params(); + } else { + ruby_whisper_parakeet_context_params *rwpcp; + GetParakeetContextParams(context_params, rwpcp); + params = rwpcp->params; + } + ruby_whisper_parakeet_context_init_args init_args = { + &rwpc->context, + StringValueCStr(model_path), + params, + }; + rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL); + if (rwpc->context == NULL) { + rb_raise(rb_eRuntimeError, "Failed to load model"); + } + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_context_full_n_segments(VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + return INT2NUM(parakeet_full_n_segments(rwpc->context)); +} + +#define DEF_SEGMENT_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \ + } + +ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR) + +#define DEF_TOKEN_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \ + } + +ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR) + +static VALUE +ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token)); + + return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data); +} + +static VALUE +ruby_whisper_parakeet_context_each_segment(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + const int n_segments = parakeet_full_n_segments(rwpc->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(ruby_whisper_parakeet_segment_init(self, i)); + } + + return self; +} + +typedef struct { + struct parakeet_context *context; + struct parakeet_full_params *params; + float *samples; + int n_samples; + int result; +} parakeet_full_without_gvl_args; + +static void* +parakeet_full_without_gvl(void *rb_args) +{ + parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args; + args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_user_data *abort_callback_user_data; +} parakeet_full_ubf_args; + +static void +parakeet_full_ubf(void *rb_args) +{ + parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args; + + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); +} + +VALUE +ruby_whisper_parakeet_context_full_body(VALUE rb_args) +{ + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(*args->context, rwpc); + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(*args->params, rwpp); + + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data); + + parakeet_full_without_gvl_args full_without_gvl_args = { + rwpc->context, + &rwpp->params, + args->samples, + args->n_samples, + 0 + }; + parakeet_full_ubf_args full_ubf_args = { + &abort_callback_user_data, + }; + rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args); + + return INT2NUM(full_without_gvl_args.result); +} + +static VALUE +ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + ruby_whisper_full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +static VALUE +ruby_whisper_parakeet_context_get_model(VALUE self) +{ + return ruby_whisper_parakeet_model_s_new(self); +} + +VALUE +init_ruby_whisper_parakeet_context(VALUE *mParakeet) +{ + cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject); + + rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate); + + rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1); + rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2); + rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0); + rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2); + rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0); + rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0); + rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1); + +#define REGISTER_SEGMENT_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1); + + ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR) + +#define REGISTER_TOKEN_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2); + + ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR) + + return cParakeetContext; +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c new file mode 100644 index 00000000000..38bd6d57ce1 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c @@ -0,0 +1,117 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(use_gpu, BOOL) \ + ITERATOR(gpu_device, INT) + +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_INT(v) (NUM2INT(v)) +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type + +#define DEF_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + return READER(type)(rwpcp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + rwpcp->params.name = WRITER(type)(val); \ + return val; \ + } + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS +}; + +extern VALUE cParakeetContextParams; + +typedef VALUE (*param_writer_t)(VALUE, VALUE); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; + +static size_t +ruby_whisper_parakeet_context_params_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_context_params); +} + +const rb_data_type_t ruby_whisper_parakeet_context_params_type = { + "ruby_whisper_parakeet_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context_params *rwpcp; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); +} + +static VALUE +ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_context_params *rwpcp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); + rwpcp->params = parakeet_context_default_params(); + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values); + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext) +{ + cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject); + + rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate); + + rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1); + + int i = 0; +#define REGISTER_ATTR(name, type) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \ + rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \ + rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \ + i++; + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_model.c b/bindings/ruby/ext/ruby_whisper_parakeet_model.c new file mode 100644 index 00000000000..dce43c688e7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_model.c @@ -0,0 +1,84 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(n_vocab) \ + ITERATOR(n_audio_ctx) \ + ITERATOR(n_audio_state) \ + ITERATOR(n_audio_head) \ + ITERATOR(n_audio_layer) \ + ITERATOR(n_mels) \ + ITERATOR(ftype) + +extern rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE cParakeetModel; + +static void +ruby_whisper_parakeet_model_mark(void *p) +{ + ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p; + if (!NIL_P(rwpm->context)) { + rb_gc_mark(rwpm->context); + } +} + +static size_t +ruby_whisper_parakeet_model_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_model); +} + +static const rb_data_type_t ruby_whisper_parakeet_model_type = { + "ruby_whisper_parakeet_model", + {ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_model_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_model *rwpm; + VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = Qnil; + + return model; +} + +VALUE +ruby_whisper_parakeet_model_s_new(VALUE context) +{ + const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel); + ruby_whisper_parakeet_model *rwpm; + TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = context; + return model; +} + +#define DEF_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_model_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_model *rwpm; \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetModel(self, rwpm); \ + GetParakeetContext(rwpm->context, rwpc); \ + return INT2NUM(parakeet_model_##name(rwpc->context)); \ + } + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_model(VALUE *mParakeet) +{ + cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject); + + rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate); + +#define REGISTER_ATTR(name) \ + rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c new file mode 100644 index 00000000000..076e2a0cdfb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -0,0 +1,548 @@ +#include "ruby_whisper.h" + +#define ITERATE_PARAMS(ITERATOR) \ + ITERATOR(n_threads, INT) \ + ITERATOR(offset_ms, INT) \ + ITERATOR(duration_ms, INT) \ + ITERATOR(no_context, BOOL) \ + ITERATOR(audio_ctx, INT) + +#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \ + ITERATOR(new_segment, DATA) \ + ITERATOR(new_token, DATA) \ + ITERATOR(progress, DATA) \ + ITERATOR(encoder_begin, DATA) + +#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback) +#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR) + +#define ITERATE_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATOR(abort_callback) + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data, + ITERATE_PARAMS(DEF_IDX) + ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK) + ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA) + + RUBY_WHISPER_PARAKEET_NUM_PARAMS +}; + +#define VAL_TO_INT(v) (NUM2INT(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse) + +extern VALUE cParakeetParams; +extern ID id_call; + +extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc); +extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void); +extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; +typedef VALUE (*param_writer_t)(VALUE, VALUE); +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int n_new; +} call_parakeet_new_segment_callbacks_args; + +static void* +call_parakeet_new_segment_callbacks(void *v_args) +{ + call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + const int n_segments = parakeet_full_n_segments_from_state(args->state); + for (int i = args->n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment); + for (int j = 0; j < n_callbacks; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_segment_callbacks_args args = { + container, + state, + n_new, + }; + rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_context *context; + struct parakeet_state *state; + const parakeet_token_data *token_data; +} call_parakeet_new_token_callbacks_args; + +static void* +call_parakeet_new_token_callbacks(void *v_args) +{ + call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args; + VALUE token = Qnil; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + if (NIL_P(token)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + } + for (int i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, token); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_token_callbacks_args args = { + container, + context, + state, + token_data, + }; + rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int progress; +} call_parakeet_progress_callbacks_args; + +static void* +call_parakeet_progress_callback(void *v_args) +{ + call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, INT2NUM(args->progress)); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_progress_callbacks_args args = { + container, + state, + progress, + }; + rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + bool is_continued; +} call_parakeet_encoder_begin_callbacks_args; + +static void* +call_parakeet_encoder_begin_callbacks(void *v_args) +{ + call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_parakeet_encoder_begin_callbacks_args args = { + container, + state, + true, + }; + rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_callback_container *container; + bool is_interrupted; +} call_parakeet_abort_callbacks_args; + +static void* +call_parakeet_abort_callbacks(void *v_args) +{ + call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + VALUE cb; + for (long i = 0; i < n_callbacks; i++) { + cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_abort_callback(void *user_data) +{ + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; + + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { + return true; + } + + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { + return false; + } + + call_parakeet_abort_callbacks_args args = { + data->callback_container, + false, + }; + rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args); + + return args.is_interrupted; +} + +#define CALLBACK_CONTAINER_NAME(name) name ## _container + +void +ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) +{ +#define PARAM_NAME(name) name +#define USER_DATA_NAME(name) name##_user_data +#define REGISTER_CALLBACK(name) \ + if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \ + rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \ + rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \ + rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \ + } + + ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK) + + if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) { + abort_callback_user_data->callback_container = rwpp->abort_callback_container; + } + rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback; + rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data; +} + +static void +ruby_whisper_parakeet_params_mark(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define MARK_CONTAINER(name) \ + if (rwpp->name##_container) { \ + ruby_whisper_callback_container_mark(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(MARK_CONTAINER) +} + +static void +ruby_whisper_parakeet_params_free(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define FREE_CONTAINER(name) \ + if (rwpp->name##_container) { \ + xfree(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(FREE_CONTAINER) + + xfree(rwpp); +} + +static size_t +ruby_whisper_parakeet_params_memsize(const void *p) +{ + const struct ruby_whisper_parakeet_params *params = p; + if (!params) { + return 0; + } + return sizeof(ruby_whisper_parakeet_params); +} + +const rb_data_type_t ruby_whisper_parakeet_params_type = { + "ruby_whisper_parakeet_params", + {ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,}, + 0, 0, + 0 +}; + +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type +#define DEF_PARAM_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return READER(type)(rwpp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->params.name = WRITER(type)(val); \ + return val; \ + } + +#define DEF_CALLBACK_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \ + return val; \ + } + +#define DEF_USER_DATA_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \ + return val; \ + } + +#define DEF_HOOK(name, data) \ + static VALUE \ + ruby_whisper_parakeet_params_on_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + const VALUE blk = rb_block_proc(); \ + if (NIL_P(rwpp->name##_callback_container->callbacks)) { \ + rwpp->name##_callback_container->callbacks = rb_ary_new(); \ + } \ + rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \ + return Qnil; \ + } + +ITERATE_PARAMS(DEF_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR) +ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _) + +static VALUE +ruby_whisper_parakeet_params_abort_on(VALUE self) +{ + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(self, rwpp); + const VALUE blk = rb_block_proc(); + if (NIL_P(rwpp->abort_callback_container->callbacks)) { + rwpp->abort_callback_container->callbacks = rb_ary_new(); + } + rb_ary_push(rwpp->abort_callback_container->callbacks, blk); + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_params *rwpp; + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + return obj; +} + +static VALUE +ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_params *rwpp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + +#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate(); + + ITERATE_CALLBACK_PARAMS(INIT_CONTAINER) + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values); + + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +void +init_ruby_whisper_parakeet_params(VALUE *mParakeet) +{ + cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject); + rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate); + + rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1); + + int i = 0; +#define REGISTER_PARAM(name) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_params_set_##name; \ + rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \ + rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \ + i++; + +#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name) +#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name) +#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data) + + ITERATE_PARAMS(REGISTER_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR) + +#define REGISTER_HOOK(name, data) \ + rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0); + + ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _) + + rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_segment.c b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c new file mode 100644 index 00000000000..b1e81ba930c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c @@ -0,0 +1,157 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(start_time, t0, TIME) \ + ITERATOR(end_time, t1, TIME) \ + ITERATOR(text, text, STRING) + +enum { +#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, +}; + +#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10)) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) +#define READER(type) VAL_FROM_##type +#define DEF_ATTR(rb_name, c_name, type) \ + static VALUE \ + ruby_whisper_parakeet_get_##rb_name(VALUE self) \ + { \ + ruby_whisper_parakeet_segment *rwps; \ + GetParakeetSegment(self, rwps); \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(rwps->context, rwpc); \ + return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \ + } + +extern ID id___method__; +extern ID id_to_enum; +extern VALUE cParakeetSegment; +extern VALUE sym_start_time; +extern VALUE sym_end_time; +extern VALUE sym_text; +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token); + +static void +rb_whisper_parakeet_segment_mark(void *p) +{ + ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p; + rb_gc_mark(rwps->context); +} + +static size_t +ruby_whisper_parakeet_segment_memsize(const void *p) +{ + const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p; + if (!rwps) { + return 0; + } + return sizeof(*rwps); +} + +static const rb_data_type_t ruby_whisper_parakeet_segment_type = { + "ruby_whisper_parakeet_segment", + {rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_segment_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_segment *rwps; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); +} + +VALUE +ruby_whisper_parakeet_segment_init(VALUE context, int index) +{ + ruby_whisper_parakeet_segment *rwps; + + const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment); + TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); + rwps->context = context; + rwps->index = index; + + return segment; +} + +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_segment_each_token(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index); + for (int i = 0; i < n_tokens; i++) { + rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i)); + } + + return self; +} + +static VALUE +ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + VALUE hash = rb_hash_new(); + long n_keys; + if (NIL_P(keys)) { + keys = rb_ary_new3( + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, + sym_start_time, + sym_end_time, + sym_text + ); + n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) { + return hash; + } + } + for (int i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, c_name, type) \ + if (key == sym_##rb_name) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \ + } + + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_segment(VALUE *mParakeet) +{ + cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject); + + rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate); + +#define REGISTER_ATTR(rb_name, c_name, type) \ + rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0); + rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_token.c b/bindings/ruby/ext/ruby_whisper_parakeet_token.c new file mode 100644 index 00000000000..a00b7ae1cbb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_token.c @@ -0,0 +1,188 @@ +#include "ruby_whisper.h" + +#define ITERATE_MEMBERS(ITERATOR) \ + ITERATOR(id, id, id, id, INT) \ + ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \ + ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \ + ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \ + ITERATOR(probability, probability, p, p, FLOAT) \ + ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \ + ITERATOR(start_time, start_time, start_time, t0, TIME) \ + ITERATOR(end_time, end_time, end_time, t1, TIME) \ + ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL) + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(text, text, text, text, STRING) + +enum { +#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name, + + ITERATE_MEMBERS(DEF_IDX) + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, +}; + +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_FROM_FLOAT(v) (DBL2NUM(v)) +#define VAL_FROM_TIME(v) (LONG2NUM(v * 10)) +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) + +#define READER(type) VAL_FROM_##type +#define MEMBER_NAME(name) name +#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \ + } + +#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return rwpt->p_name; \ + } + +VALUE cParakeetToken; + +#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key; + +ITERATE_MEMBERS(DEC_ATTR_SYMS) +ITERATE_ATTRS(DEC_ATTR_SYMS) + +static void +ruby_whisper_parakeet_token_mark(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + rb_gc_mark(rwpt->text); +} + +static void +ruby_whisper_parakeet_token_free(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (rwpt->token_data) { + xfree(rwpt->token_data); + rwpt->token_data = NULL; + } + xfree(rwpt); +} + +static size_t +ruby_whisper_parakeet_token_memsize(const void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (!rwpt) { + return 0; + } + size_t size = sizeof(*rwpt); + if (rwpt->token_data) { + size += sizeof(*rwpt->token_data); + } + + return size; +} + +static const rb_data_type_t ruby_whisper_parakeet_token_type = { + "ruby_whisper_parakeet_token", + {ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_token_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_token *rwpt; + VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = NULL; + rwpt->text = Qnil; + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data) +{ + const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken); + ruby_whisper_parakeet_token *rwpt; + TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = ALLOC(parakeet_token_data); + *rwpt->token_data = *token_data; + rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id)); + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token) +{ + parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token); + return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data); +} + +ITERATE_MEMBERS(DEF_MEMBER_ATTR) +// Define #text using parakeet_token_to_str or parakeet_token_to_text +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_token *rwpt; + GetParakeetToken(self, rwpt); + + VALUE hash = rb_hash_new(); + long n_keys = 0; + + if (NIL_P(keys)) { + VALUE attrs[] = { +#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key, + + ITERATE_MEMBERS(LIST_SYMS) + ITERATE_ATTRS(LIST_SYMS) + }; + keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs); + n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) { + return hash; + } + } + for (long i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \ + if (key == sym_##s_key) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \ + } + + ITERATE_MEMBERS(CHECK_AND_SET_KEY) + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_token(VALUE *mParakeet) +{ + cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject); + rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate); + +#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \ + sym_##s_key = ID2SYM(rb_intern(#s_key)); \ + rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0); + + ITERATE_MEMBERS(REGISTER_ATTR) + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp new file mode 100644 index 00000000000..c4deccce84a --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp @@ -0,0 +1,58 @@ +#include "ruby_whisper.h" +#include "common-whisper.h" +#include <string> +#include <vector> + +#ifdef __cplusplus +extern "C" { +#endif + +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern const rb_data_type_t ruby_whisper_parakeet_params_type; + +extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args); + +extern ID id_to_path; +extern ID id_new; + +extern VALUE eError; + +VALUE +ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) +{ + if (rb_respond_to(audio_path, id_to_path)) { + audio_path = rb_funcall(audio_path, id_to_path, 0); + } + + std::string fname = StringValueCStr(audio_path); + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + + if (!read_audio_data(fname, pcmf32, pcmf32s, false)) { + rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str()); + return Qnil; + } + + ruby_whisper_parakeet_context *rwpc; + ruby_whisper_parakeet_params *rwpp; + GetParakeetContext(self, rwpc); + GetParakeetParams(params, rwpp); + + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 2aae7c12d19..f38e9bde3ea 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -76,8 +76,8 @@ static ID id_vad; static ID id_vad_model_path; static ID id_vad_params; -static void -rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) +void +ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc) { if (rwc == NULL) return; @@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) rb_gc_mark(rwc->callbacks); } -static ruby_whisper_callback_container* -rb_whisper_callback_container_allocate() { +ruby_whisper_callback_container* +ruby_whisper_callback_container_allocate() { ruby_whisper_callback_container *container; container = ALLOC(ruby_whisper_callback_container); container->context = NULL; @@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() { return container; } -static void -rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) -{ - if (rwc == NULL) return; - - rb_gc_mark(rwc->user_data); - rb_gc_mark(rwc->callback); - rb_gc_mark(rwc->callbacks); -} - -static ruby_whisper_abort_callback_container* -rb_whisper_abort_callback_container_allocate() { - ruby_whisper_abort_callback_container *container; - container = ALLOC(ruby_whisper_abort_callback_container); - container->context = NULL; - container->user_data = Qnil; - container->callback = Qnil; - container->callbacks = Qnil; - container->is_interrupted = false; - return container; -} - -static bool +bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { return !NIL_P(container->callback) || !NIL_P(container->callbacks); } -static bool -ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { - return !NIL_P(container->callback) || !NIL_P(container->callbacks); -} - typedef struct { const ruby_whisper_callback_container *container; struct whisper_state *state; @@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s } typedef struct { - const ruby_whisper_abort_callback_container *container; - struct whisper_state *state; + const ruby_whisper_callback_container *container; bool is_interrupted; } call_abort_callbacks_args; static void* call_abort_callbacks(void *v_args) { call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; - const ruby_whisper_abort_callback_container *container = args->container; - - if (container->is_interrupted) { - args->is_interrupted = true; - return NULL; - } + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) { if (NIL_P(container->callbacks)) { return NULL; } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (0 == n_callbacks) { return NULL; } - for (int j = 0; j < callbacks_len; j++) { + for (int j = 0; j < n_callbacks; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + VALUE result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) { } static bool abort_callback(void * user_data) { - const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; - if (container->is_interrupted) { + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { return true; } - if (!ruby_whisper_abort_callback_container_is_present(container)) { + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { return false; } call_abort_callbacks_args args = { - container, - NULL, + data->callback_container, false }; rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); @@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors) return; } - if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { - rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); - } - - if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { - rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); - } + // new_segment_callback is called only after multiple threads are joined + // progress_callback is not called when parallel if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { + if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } - - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (!NIL_P(log_callback)) { - rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription"); - } } -static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { +static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; @@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } + abort_callback_user_data->callback_container = rwp->abort_callback_container; rwp->abort_callback_container->context = context; rwp->params.abort_callback = abort_callback; - rwp->abort_callback_container->is_interrupted = false; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; + rwp->params.abort_callback_user_data = (void *)abort_callback_user_data; } static void set_vad_params(ruby_whisper_params *rwp) @@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp) rwp->params.vad_params = rwvp->params; } -/* - TODO: Set abort callback to trap SIGINT and SIGTERM -*/ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { check_thread_safety(rwp, n_processors); - register_callbacks(rwp, context); + register_callbacks(rwp, context, abort_callback_user_data); set_vad_params(rwp); } @@ -421,10 +376,10 @@ void rb_whisper_params_mark(void *p) { ruby_whisper_params *rwp = (ruby_whisper_params *)p; - rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); - rb_whisper_callbcack_container_mark(rwp->progress_callback_container); - rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); + ruby_whisper_callback_container_mark(rwp->new_segment_callback_container); + ruby_whisper_callback_container_mark(rwp->progress_callback_container); + ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container); + ruby_whisper_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass) } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); - rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); - rwp->progress_callback_container = rb_whisper_callback_container_allocate(); - rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); + rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); + rwp->progress_callback_container = ruby_whisper_callback_container_allocate(); + rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); + rwp->abort_callback_container = ruby_whisper_callback_container_allocate(); return obj; } diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index ee0d66c4cc8..cf0372797d3 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -4,12 +4,12 @@ extern ID id___method__; extern ID id_to_enum; -static VALUE sym_start_time; -static VALUE sym_end_time; -static VALUE sym_text; -static VALUE sym_no_speech_prob; -static VALUE sym_speaker_turn_next; -static VALUE sym_n_tokens; +VALUE sym_start_time; +VALUE sym_end_time; +VALUE sym_text; +VALUE sym_no_speech_prob; +VALUE sym_speaker_turn_next; +VALUE sym_n_tokens; extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 37656af1c44..73f606ca476 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,6 +16,8 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern VALUE full_body(VALUE rb_args); +extern VALUE full_parallel_body(VALUE rb_args); typedef struct{ struct whisper_context *context; @@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args) return NULL; } -typedef struct { - ruby_whisper_abort_callback_container *abort_callback_container; -} transcribe_ubf_args; - -static void -transcribe_ubf(void *rb_args) -{ - transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; - - args->abort_callback_container->is_interrupted = true; -} - /* * transcribe a single file * can emit to a block results @@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return self; } - // Commented out because it is work in progress - // { - // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - - // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - // bool is_aborted = *(bool*)user_data; - // return !is_aborted; - // }; - // rwp->params.encoder_begin_callback_user_data = &is_aborted; - // } - - prepare_transcription(rwp, &self, n_processors); - - transcribe_without_gvl_args args = { - rw->context, - &rwp->params, - pcmf32.data(), - pcmf32.size(), - n_processors, - 0, - }; - transcribe_ubf_args ubf_args = { - rwp->abort_callback_container, - }; - rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); - if (args.result != 0) { + + VALUE rb_result; + if (n_processors == 1) { + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + rb_result = full_body((VALUE)&args); + } else { + ruby_whisper_full_parallel_args parallel_args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + n_processors, + }; + rb_result = full_parallel_body((VALUE)¶llel_args); + } + const int result = NUM2INT(rb_result); + if (result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/lib/whisper/context.rb b/bindings/ruby/lib/whisper/context.rb deleted file mode 100644 index c3a134b773d..00000000000 --- a/bindings/ruby/lib/whisper/context.rb +++ /dev/null @@ -1,15 +0,0 @@ -module Whisper - class Context - def to_srt - each_segment.with_index.reduce("") {|srt, (segment, index)| - srt << "#{index + 1}\n#{segment.to_srt_cue}\n" - } - end - - def to_webvtt - each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| - webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" - } - end - end -end diff --git a/bindings/ruby/lib/whisper/log_settable.rb b/bindings/ruby/lib/whisper/log_settable.rb new file mode 100644 index 00000000000..2f8218d26ee --- /dev/null +++ b/bindings/ruby/lib/whisper/log_settable.rb @@ -0,0 +1,36 @@ +require "mutex_m" + +module Whisper + module LogSettable + class << self + def extended(base) + base.extend Mutex_m + end + end + + private + + def start_log_callback_thread + return if @log_callback_thread&.alive? + + @log_callback_thread = Thread.new { + begin + while logs = drain_logs + begin + callback, user_data = synchronize {[@log_callback, @log_callback_user_data]} + next if callback.nil? + + logs.each do |(level, text)| + callback.call level, text, user_data + end + rescue => err + $stderr.puts err + end + end + rescue => err + $stderr.puts err + end + } + end + end +end diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 8eb57e5e8cf..ef92eb901c4 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -41,6 +41,8 @@ def base_cache_dir def cache path = cache_path + return path if cache_path.exist? + headers = {} headers["if-modified-since"] = path.mtime.httpdate if path.exist? request @uri, headers @@ -216,8 +218,18 @@ def escaping(path) @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin") end + %w[ + parakeet-tdt-0.6b-v3-f16 + parakeet-tdt-0.6b-v3-f32 + parakeet-tdt-0.6b-v3-q4_0 + parakeet-tdt-0.6b-v3-q4_k + parakeet-tdt-0.6b-v3-q8_0 + ].each do |name| + @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin") + end + @coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models| - next if name.end_with?("-tdrz") || name.start_with?("silero-") + next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-") if matched = name.match(/\A(?<name>.*)-q\d_\d\z/) name = matched[:name] diff --git a/bindings/ruby/lib/whisper/output.rb b/bindings/ruby/lib/whisper/output.rb new file mode 100644 index 00000000000..1781af17a33 --- /dev/null +++ b/bindings/ruby/lib/whisper/output.rb @@ -0,0 +1,74 @@ +module Whisper + module Output + module Context + def to_srt + each_segment.with_index.reduce("") {|srt, (segment, index)| + srt << "#{index + 1}\n#{segment.to_srt_cue}\n" + } + end + + def to_webvtt + each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| + webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" + } + end + end + + module Segment + SRT_ESCAPES = { + "&" => "&", + "<" => "<", + ">" => ">", + } + SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) + private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE + + def to_srt_cue + "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" + end + + def to_webvtt_cue + "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" + end + + private + + def time_to_a(time) + sec, decimal_part = time.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + [hour, min, sec, decimal_part] + end + + def srt_time(time) + "%02d:%02d:%02d,%03d" % time_to_a(time) + end + + def srt_start_time + srt_time(start_time) + end + + def srt_end_time + srt_time(end_time) + end + + def srt_text + text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) + end + + def webvtt_time(time) + "%02d:%02d:%02d.%03d" % time_to_a(time) + end + + def webvtt_start_time + webvtt_time(start_time) + end + + def webvtt_end_time + webvtt_time(end_time) + end + + alias webvtt_text srt_text + end + end +end diff --git a/bindings/ruby/lib/whisper/segment.rb b/bindings/ruby/lib/whisper/segment.rb deleted file mode 100644 index dc187dcac36..00000000000 --- a/bindings/ruby/lib/whisper/segment.rb +++ /dev/null @@ -1,58 +0,0 @@ -module Whisper - class Segment - SRT_ESCAPES = { - "&" => "&", - "<" => "<", - ">" => ">", - } - SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) - private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE - - def to_srt_cue - "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" - end - - def to_webvtt_cue - "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" - end - - private - - def time_to_a(time) - sec, decimal_part = time.divmod(1000) - min, sec = sec.divmod(60) - hour, min = min.divmod(60) - [hour, min, sec, decimal_part] - end - - def srt_time(time) - "%02d:%02d:%02d,%03d" % time_to_a(time) - end - - def srt_start_time - srt_time(start_time) - end - - def srt_end_time - srt_time(end_time) - end - - def srt_text - text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) - end - - def webvtt_time(time) - "%02d:%02d:%02d.%03d" % time_to_a(time) - end - - def webvtt_start_time - webvtt_time(start_time) - end - - def webvtt_end_time - webvtt_time(end_time) - end - - alias webvtt_text srt_text - end -end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index cbec4803820..c12e1fe55e5 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -40,7 +40,21 @@ module Whisper def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String + module Output + module Context + def to_srt: () -> String + def to_webvtt: () -> String + end + + module Segment + def to_srt_cue: () -> String + def to_webvtt_cue: () -> String + end + end + class Context + include Output::Context + def self.new: (String | path | ::URI::HTTP) -> instance # transcribe a single file @@ -139,17 +153,14 @@ module Whisper | (Whisper::Params, _Samples, ?Integer n_samples) -> self | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self - def to_srt: () -> String - def to_webvtt: () -> String - class Params def self.new: ( - use_gpu: boolish, - flash_attn: boolish, - gpu_device: Integer, - dtw_token_timestamps: boolish, - dtw_aheads_preset: Integer, - dtw_n_top: Integer | nil, + ?use_gpu: boolish, + ?flash_attn: boolish, + ?gpu_device: Integer, + ?dtw_token_timestamps: boolish, + ?dtw_aheads_preset: Integer, + ?dtw_n_top: Integer | nil, ) -> instance def use_gpu=: (boolish) -> boolish @@ -444,6 +455,9 @@ module Whisper def abort_on: { (Object user_data) -> boolish } -> void end + module LogSettable + end + class Model def self.pre_converted_models: () -> Hash[String, Model::URI] def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI] @@ -474,6 +488,8 @@ module Whisper end class Segment + include Output::Segment + type deconstructed_keys = { start_time: (Integer | nil), end_time: (Integer | nil), @@ -514,9 +530,6 @@ module Whisper # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] - def to_srt_cue: () -> String - def to_webvtt_cue: () -> String - # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # @@ -528,7 +541,7 @@ module Whisper def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys end - module Token + class Token type deconstructed_keys = { id: (Integer | nil), tid: (Integer | nil), @@ -598,6 +611,336 @@ module Whisper def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys end + module Parakeet + extend LogSettable + + VERSION: String + + # Control logging output. The default behavior is to print to stderr. + # + def self.log_set: (nil, Object? user_data) -> nil + | (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil + def self.system_info_str: () -> String + + class Context + include Output::Context + + # Load a Parakeet model from the given file path. + # + def self.new: (String | path | ::URI::HTTP, ?Params) -> instance + + # Transcribe a single audio file. + # + def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self + + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + # Not thread safe for the same context. + # + # The second argument `samples` must be an array of samples, respond to `:length`, + # or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # + def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self + + # Number of generated text segments. + # + def full_n_segments: () -> Integer + + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t0(3) # => 1668 (16680 ms) + # + def full_get_segment_t0: (Integer segment_index) -> Integer + + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t1(3) # => 1668 (16680 ms) + # + def full_get_segment_t1: (Integer segment_index) -> Integer + + # Text of a segment indexed by `segment_index`. + # + # full_get_segment_text(3) # => "ask not what your country can do for you, ..." + # + def full_get_segment_text: (Integer segment_index) -> String + + # Number of tokens in the segment indexed by `segment_index`. + # + def full_n_tokens: (Integer segment_index) -> Integer + + # Text of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_text: (Integer segment_index, Integer token_index) -> String + + # Token id of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer + + # Probability of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_p: (Integer segment_index, Integer token_index) -> Float + + # Token data of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_data: (Integer segment_index, Integer token_index) -> Token + + def model: () -> Model + + # Yields each Whisper::Parakeet::Segment: + # + # parakeet.transcribe("path/to/audio.wav", params) + # parakeet.each_segment do |segment| + # puts segment.text + # end + # + # Returns an `Enumerator` if no block given: + # + # parakeet.transcribe("path/to/audio.wav", params) + # enum = parakeet.each_segment + # enum.to_a # => [#<Whisper::Parakeet::Segment>, ...] + # + def each_segment: { (Segment) -> void } -> void + | () -> Enumerator[Segment] + + class Params + def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance + def use_gpu: () -> boolish + def use_gpu=: (boolish) -> boolish + def gpu_device: () -> Integer + def gpu_device=: (Integer) -> Integer + end + end + + class Params + def self.new: ( + ?n_threads: Integer, + ?offset_ms: Integer, + ?duration_ms: Integer, + ?no_context: boolish, + ?audio_ctx: Integer, + ?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void, + ?new_segment_callback_user_data: Object, + ?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void, + ?new_token_callback_user_data: Object, + ?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void, + ?progress_callback_user_data: Object, + ?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish, + ?encoder_begin_callback_user_data: Object, + ?abort_callback: ^(Object user_data) -> boolish, + ?abort_callback_user_data: Object + ) -> instance + + # Number of threads to use. + # + def n_threads=: (Integer) -> Integer + def n_threads: () -> Integer + + # Start offset in ms. + # + def offset_ms=: (Integer) -> Integer + def offset_ms: () -> Integer + + # Audio duration to process in ms. + # + def duration_ms=: (Integer) -> Integer + def duration_ms: () -> Integer + + # If `true`, does not use past transcription (if any) as context. + # + def no_context=: (boolish) -> boolish + def no_context: () -> (true | false) + + # Overwrite the audio context size. `0` uses the default value. + # + def audio_ctx=: (Integer) -> Integer + def audio_ctx: () -> Integer + + # Sets new segment callback, called for every newly generated text segment. + # + # params.new_segment_callback = ->(context, _, n_new, user_data) { + # # ... + # } + # + def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) + def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of new segment callback. + # + def new_segment_callback_user_data=: (Object?) -> Object? + def new_segment_callback_user_data: () -> Object? + + # Sets token callback, called for every newly predicted token. + # + def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) + def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of token callback. + # + def new_token_callback_user_data=: (Object?) -> Object? + def new_token_callback_user_data: () -> Object? + + # Sets progress callback, called on each progress update. + # + # +progress+ is an Integer between 0 and 100. + # + def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) + def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of progress callback. + # + def progress_callback_user_data=: (Object?) -> Object? + def progress_callback_user_data: () -> Object? + + # Sets encoder begin callback, called each time before the encoder starts. + # + # If it returns `false`, the computation is aborted. + # + def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) + def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of encoder begin callback. + # + def encoder_begin_callback_user_data=: (Object?) -> Object? + def encoder_begin_callback_user_data: () -> Object? + + # Sets abort callback, called each time before ggml computation starts. + # + def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish) + def abort_callback: () -> ((^(Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of abort callback. + # + def abort_callback_user_data=: (Object?) -> Object? + def abort_callback_user_data: () -> Object? + + # Hook called on new segment. Yields each Whisper::Parakeet::Segment. + # + def on_new_segment: { (Segment) -> void } -> void + + # Hook called on new token. Yields each Whisper::Parakeet::Token. + # + def on_new_token: { (Token) -> void } -> void + + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. + # + def on_progress: { (Integer progress) -> void } -> void + + # Hook called each time before the encoder starts. + # + def on_encoder_begin: { () -> boolish } -> void + + # Call block to determine whether abort or not. Return `true` when you want to abort. + # + def abort_on: { () -> boolish } -> void + end + + class Segment + include Output::Segment + + type deconstructed_keys = { + start_time: (Integer | nil), + end_time: (Integer | nil), + text: (String | nil) + } + + # Start time in milliseconds. + # + def start_time: () -> Integer + + # End time in milliseconds. + # + def end_time: () -> Integer + + # Text of the segment. + # + def text: () -> String + + # Yields each Whisper::Parakeet::Token: + # + # parakeet.each_segment.first.each_token do |token| + # p token + # end + # + # Returns an `Enumerator` if no block is given: + # + # parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...] + # + def each_token: { (Token) -> void } -> void + | () -> Enumerator[Token] + + # Possible keys: `:start_time`, `:end_time`, `:text` + # + def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys + end + + class Token + type deconstructed_keys = { + id: (Integer | nil), + duration_idx: (Integer | nil), + duration_value: (Integer | nil), + frame_index: (Integer | nil), + probability: (Float | nil), + log_probability: (Float | nil), + start_time: (Integer | nil), + end_time: (Integer | nil), + word_start: ((true | false) | nil), + text: (String | nil), + } + + # Token ID. + # + def id: () -> Integer + + # Index into the model's durations array. + # + def duration_idx: () -> Integer + + # Actual duration value. + # + def duration_value: () -> Integer + + # Frame index of the token. + # + def frame_index: () -> Integer + + # Probability of the token. + # + def probability: () -> Float + + # Log probability of the token. + # + def log_probability: () -> Float + + # Start time of the token in milliseconds. + # + def start_time: () -> Integer + + # End time of the token in milliseconds. + # + def end_time: () -> Integer + + # Whether this token is the start of a word. + # + def word_start?: () -> (true | false) + + # Get the token text of the token. + # + def text: () -> String + + def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys + end + + class Model + def n_vocab: () -> Integer + def n_audio_ctx: () -> Integer + def n_audio_state: () -> Integer + def n_audio_head: () -> Integer + def n_audio_layer: () -> Integer + def n_mels: () -> Integer + def ftype: () -> Integer + end + end + module VAD class Params def self.new: ( diff --git a/bindings/ruby/test/helper.rb b/bindings/ruby/test/helper.rb index 56cd3849fdd..5e37ad98596 100644 --- a/bindings/ruby/test/helper.rb +++ b/bindings/ruby/test/helper.rb @@ -5,6 +5,8 @@ class TestBase < Test::Unit::TestCase AUDIO = File.join(__dir__, "fixtures", "jfk.wav") + Parakeet = Whisper::Parakeet + class << self def whisper return @whisper if @whisper diff --git a/bindings/ruby/test/test_callback.rb b/bindings/ruby/test/test_callback.rb index a7f49245ade..6490c8abb48 100644 --- a/bindings/ruby/test/test_callback.rb +++ b/bindings/ruby/test/test_callback.rb @@ -129,6 +129,7 @@ def test_encoder_begin_callback_abort return false } @whisper.transcribe(@audio, @params) + sleep 0.5 # wait for logs dequeued assert_match(/encoder_begin_callback returned false - aborting/, logs.join) Whisper.log_set ->(level, buffer, user_data) {}, nil end diff --git a/bindings/ruby/test/test_parakeet.rb b/bindings/ruby/test/test_parakeet.rb new file mode 100644 index 00000000000..bfd57076f56 --- /dev/null +++ b/bindings/ruby/test/test_parakeet.rb @@ -0,0 +1,28 @@ +require_relative "helper" +require "stringio" + +class TestParakeet < TestBase + def test_log_set + log_callback = Parakeet.instance_variable_get("@log_callback") + user_data = Parakeet.instance_variable_get("@log_callback_user_data") + + $stdout = StringIO.new + Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil + Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + sleep 0.1 + $stdout.rewind + logs = $stdout.string + assert_match /loading model from/, logs + ensure + $stdout = STDOUT + Parakeet.log_set log_callback, user_data + end + + def test_system_info_str + assert_match /\APARAKEET : /, Parakeet.system_info_str + end + + def test_version + assert_instance_of String, Parakeet::VERSION + end +end diff --git a/bindings/ruby/test/test_parakeet_callback.rb b/bindings/ruby/test/test_parakeet_callback.rb new file mode 100644 index 00000000000..1209e960f09 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_callback.rb @@ -0,0 +1,107 @@ +require_relative "helper" + +class TestParakeetCallback < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + @params = Parakeet::Params.new + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @parakeet, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0 + end + } + + @parakeet.transcribe AUDIO, @params + end + + def test_on_new_segment + seg = nil + index = 0 + @params.on_new_segment do |segment| + assert_instance_of Parakeet::Segment, segment + if index == 0 + seg = segment + assert_equal 0, segment.start_time + assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text) + end + index += 1 + end + @parakeet.transcribe AUDIO, @params + assert_equal 0, seg.start_time + assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text + end + + def test_on_new_token + index = 0 + @params.on_new_token do |token| + assert_instance_of Parakeet::Token, token + if index == 0 + assert_instance_of Integer, token.start_time + assert_match "▁And", token.text + end + index += 1 + end + + @parakeet.transcribe AUDIO, @params + end + + def test_on_progress + first = nil + @params.on_progress do |progress| + assert_kind_of Integer, progress + assert 0 <= progress && progress <= 100 + first = progress if first.nil? + end + + @parakeet.transcribe AUDIO, @params + + assert_equal 0, first + end + + def test_on_encoder_begin + i = 0 + @params.on_encoder_begin do + i += 1 + end + + @parakeet.transcribe AUDIO, @params + + assert i > 0 + end + + def test_abort_on + do_abort = false + @params.on_new_segment do |segment| + do_abort = true if segment.text.match?(/ask/) + end + i = 0 + @params.abort_on do + i += 1 + do_abort + end + + @parakeet.transcribe(AUDIO, @params) rescue nil + + assert i > 0 + end +end diff --git a/bindings/ruby/test/test_parakeet_context.rb b/bindings/ruby/test/test_parakeet_context.rb new file mode 100644 index 00000000000..2d039ce75f5 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context.rb @@ -0,0 +1,116 @@ +require_relative "helper" +require "stringio" + +class TestParakeetContext < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Context, @parakeet + end + + def test_new_with_params + log_callback = Parakeet.instance_variable_get(:@log_callback) + user_data = Parakeet.instance_variable_get(:@log_callback_user_data) + begin + logs = "" + Parakeet.log_set proc {|level, message| logs << message}, nil + params = Parakeet::Context::Params.new(use_gpu: false) + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params) + assert_instance_of Parakeet::Context, parakeet + assert_match /use gpu\s+=\s+0/, logs + ensure + Parakeet.log_set log_callback, user_data + end + end + + sub_test_case "full" do + def setup + super + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} + end + + def test_full + @parakeet.full @params, @samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text + end + + def test_full_without_length + @parakeet.full(@params, @samples) + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator + samples = @samples.each + @parakeet.full @params, samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @parakeet.full @params, samples + end + end + + def test_full_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @parakeet.full @params, samples, 11 + end + end + + def test_full_with_memory_view + samples = JFKReader.new(AUDIO) + @parakeet.full @params, samples + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @parakeet.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + end + + def test_transcribe + assert_nothing_raised do + @parakeet.transcribe AUDIO, @params + end + end + + def test_transcribe_with_pathname + assert_nothing_raised do + @parakeet.transcribe Pathname(AUDIO), @params + end + end + + def test_transcribe_with_nothing + assert_raise_message(/open/) do + @parakeet.transcribe "nothing", @params + end + end +end diff --git a/bindings/ruby/test/test_parakeet_context_params.rb b/bindings/ruby/test/test_parakeet_context_params.rb new file mode 100644 index 00000000000..fcd0f2410f7 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context_params.rb @@ -0,0 +1,24 @@ +require_relative "helper" + +class TestParakeetContextParams < TestBase + def setup + @params = Parakeet::Context::Params.new + end + + def test_new + assert_instance_of Parakeet::Context::Params, @params + end + + def test_attributes + assert_true @params.use_gpu + assert_instance_of Integer, @params.gpu_device + end + + def test_attribute_writer + @params.use_gpu = false + assert_false @params.use_gpu + + @params.gpu_device = 2 + assert_equal 2, @params.gpu_device + end +end diff --git a/bindings/ruby/test/test_parakeet_model.rb b/bindings/ruby/test/test_parakeet_model.rb new file mode 100644 index 00000000000..5343b35ed8e --- /dev/null +++ b/bindings/ruby/test/test_parakeet_model.rb @@ -0,0 +1,21 @@ +require_relative "helper" + +class TestParakeetModel < TestBase + def test_model + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + assert_instance_of Parakeet::Model, parakeet.model + end + + def test_attributes + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + model = parakeet.model + + assert_equal 10, model.n_vocab + assert_equal 3200, model.n_audio_ctx + assert_equal 8, model.n_audio_state + assert_equal 2, model.n_audio_head + assert_equal 1, model.n_audio_layer + assert_equal 16, model.n_mels + assert_equal 0, model.ftype + end +end diff --git a/bindings/ruby/test/test_parakeet_params.rb b/bindings/ruby/test/test_parakeet_params.rb new file mode 100644 index 00000000000..dc651f7ab12 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_params.rb @@ -0,0 +1,78 @@ +require_relative "helper" +require "etc" + +class TestParakeetParams < TestBase + PARAM_NAMES = [ + :n_threads, + :offset_ms, + :duration_ms, + :no_context, + :audio_ctx + ] + + def setup + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Params, @params + end + + def test_n_threads + assert_equal [4, Etc.nprocessors].min, @params.n_threads + + @params.n_threads = 1 + assert_equal 1, @params.n_threads + end + + def test_offset_ms + assert_equal 0, @params.offset_ms + + @params.offset_ms = 10_000 + assert_equal 10_000, @params.offset_ms + end + + def test_duration_ms + assert_equal 0, @params.duration_ms + + @params.duration_ms = 60_000 + assert_equal 60_000, @params.duration_ms + end + + def test_no_context + assert_equal true, @params.no_context + + @params.no_context = false + assert_equal false, @params.no_context + end + + def test_audio_ctx + assert_equal 0, @params.audio_ctx + + @params.audio_ctx = 1 + assert_equal 1, @params.audio_ctx + end + + def test_new_with_kw_args + params = Parakeet::Params.new(n_threads: 1) + assert_equal 1, params.n_threads + assert_equal 0, params.offset_ms + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_value = @params.send(param) + value = case [param, default_value] + in [*, true | false] + !default_value + in [*, Integer] + default_value + 1 + end + params = Parakeet::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + assert_equal @params.send(name), params.send(name) + end + end +end diff --git a/bindings/ruby/test/test_parakeet_segment.rb b/bindings/ruby/test/test_parakeet_segment.rb new file mode 100644 index 00000000000..d5b99bd5ee6 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_segment.rb @@ -0,0 +1,42 @@ +require_relative "helper" + +class TestParakeetSegment < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @parakeet.transcribe AUDIO, Parakeet::Params.new + end + + def test_segment + whole_text = "" + @parakeet.each_segment do |segment| + assert_instance_of Parakeet::Segment, segment + assert_kind_of Integer, segment.start_time + assert segment.end_time >= segment.start_time + assert_kind_of String, segment.text + whole_text << segment.text + end + assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text) + end + + def test_deconstruct_keys + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text]) + end + + def test_deconstruct_keys_with_nil + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys(nil) + end +end diff --git a/bindings/ruby/test/test_parakeet_token.rb b/bindings/ruby/test/test_parakeet_token.rb new file mode 100644 index 00000000000..6f0b8b5a37c --- /dev/null +++ b/bindings/ruby/test/test_parakeet_token.rb @@ -0,0 +1,73 @@ +require_relative "helper" + +class TestParakeetToken < TestBase + ATTRS = %i[ + id + duration_idx + duration_value + frame_index + probability + log_probability + start_time + end_time + word_start? + text + ] + + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + params = Parakeet::Params.new + parakeet.transcribe AUDIO, params + @segment = parakeet.each_segment.first + end + + def test_each_token + i = 0 + @segment.each_token do |token| + i += 1 + assert_instance_of Parakeet::Token, token + end + assert_equal 38, i + end + + def test_each_token_without_block + assert_instance_of Enumerator, @segment.each_token + end + + def test_token + token = @segment.each_token.first + + assert_instance_of Parakeet::Token, token + assert_instance_of Integer, token.id + assert_instance_of Integer, token.duration_idx + assert_instance_of Integer, token.duration_value + assert_instance_of Integer, token.frame_index + assert_instance_of Float, token.probability + assert_instance_of Float, token.log_probability + assert_instance_of Integer, token.start_time + assert_instance_of Integer, token.end_time + assert_instance_of String, token.text + end + + def test_text + assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."], + @segment.each_token.collect(&:text) + end + + def test_deconstruct_keys_with_nil + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(nil) + end + + def test_deconstruct_keys_with_keys + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(expected.keys) + end +end diff --git a/bindings/ruby/test/test_vad_segment.rb b/bindings/ruby/test/test_vad_segment.rb index 7348562cb15..6d66c27fd32 100644 --- a/bindings/ruby/test/test_vad_segment.rb +++ b/bindings/ruby/test/test_vad_segment.rb @@ -9,7 +9,7 @@ def test_initialize end assert_raise do - segments.end_time + segment.end_time end assert_raise do diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index f7e25239d5d..082547e7c08 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -149,6 +149,7 @@ def test_log_set } Whisper.log_set log_callback, user_data Whisper::Context.new("base.en") + sleep 0.1 # wait for logs dequeued assert logs.length > 30 logs.each do |log| diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2d952222f29..301ecfcc13d 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -23,7 +23,7 @@ Gem::Specification.new do |s| s.test_files = s.files.select {|file| file.start_with? "test/"} s.extensions << 'ext/extconf.rb' - s.required_ruby_version = '>= 3.1.0' + s.required_ruby_version = '>= 3.3.0' #### Documentation and testing. s.homepage = 'https://github.com/ggml-org/whisper.cpp' From 86c40c3bd6fc86f1187fb751d111b49e0fc18e84 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Wed, 17 Jun 2026 11:36:57 +0200 Subject: [PATCH 286/289] release : v1.9.0 (#3886) --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dff25f25a34..8527d6d9bed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.7) +project("whisper.cpp" VERSION 1.9.0) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index 19fdc70daab..a32d9b61382 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.7](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.7) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.9.0](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.0) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 7c66c730c6c..b777591a4e3 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.7", + "version": "1.9.0", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 200b1197907545a88c5a00fb15f52e2cf88af6f5 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 18 Jun 2026 14:49:08 +0200 Subject: [PATCH 287/289] ci : add GGML_NATIVE=OFF and GGML_BMI2=OFF to windows-blas (#3891) * ci : add GGML_NATIVE=OFF and build all cpu-variants This commit adds -DGGML_BACKEND_DL=ON, -DGGML_NATIVE=OFF, and -DGGML_CPU_ALL_VARIANTS=ON to the releases. The motivation for this is that currently the Windows BLAS build uses the native CPU instructions and if target systems do not support these instructions, the build will fail like the linked issue reports. Resolves: https://github.com/ggml-org/whisper.cpp/issues/3889 * ci : update ubuntu-cpu release job for all variants [no ci] This commit enables the ubuntu-cpu job to include all cpu variants and ensures that the ggml backend libraries are built into the bin directory similar to how llama.cpp does it. The following is a build on my fork with this change: https://github.com/danbev/whisper.cpp/releases/tag/untagged-fc3c71f0bf0f7bf19d19 --- .github/workflows/release.yml | 15 +++++++++++---- CMakeLists.txt | 1 + 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ef2c3083c9f..8dcfeb9827c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,9 +115,11 @@ jobs: run: | cmake -B build \ -DCMAKE_BUILD_TYPE=Release \ - -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_RPATH='$ORIGIN' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ + -DGGML_BACKEND_DL=ON \ -DGGML_NATIVE=OFF \ - ${{ matrix.build == 'arm64' && '-DGGML_CPU_ARM_ARCH=armv8-a' || '' }} + ${{ matrix.build == 'x64' && '-DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_CPU_ARM_ARCH=armv8-a' }} cmake --build build --config Release -j $(nproc) - name: Pack artifacts @@ -173,7 +175,7 @@ jobs: -DBUILD_SHARED_LIBS=ON -DWHISPER_SDL2=${{ matrix.sdl2 }} -DGGML_NATIVE=OFF - -DGGML_BMI2=OFF + ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} - name: Build run: | @@ -287,6 +289,8 @@ jobs: -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} - name: Build run: | @@ -490,7 +494,10 @@ jobs: -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ -DSDL2_DIR="%SDL2_DIR%" ^ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ - -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" + -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CPU_ALL_VARIANTS=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% diff --git a/CMakeLists.txt b/CMakeLists.txt index 8527d6d9bed..1f95e175af4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(WHISPER_STANDALONE ON) From f049fff95a089aa9969deb009cdd4892b3e74916 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Fri, 19 Jun 2026 06:12:37 +0200 Subject: [PATCH 288/289] release : v1.9.1 (#3892) --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f95e175af4..26037c26538 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.9.0) +project("whisper.cpp" VERSION 1.9.1) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index a32d9b61382..0e2d5f100d5 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.9.0](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.0) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.9.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index b777591a4e3..09829326605 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.9.0", + "version": "1.9.1", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 195e85689bc8ad825929cf11160ed53cbbea8945 Mon Sep 17 00:00:00 2001 From: Freddy Martinez Garcia <freddy311082@gmail.com> Date: Mon, 29 Jun 2026 21:44:06 -0300 Subject: [PATCH 289/289] QVAC-21582 chore: keep Tether .github, drop upstream CI workflows from merge The v1.9.1 merge pulled in upstream's GitHub Actions workflows (build-android/clang/coreml/cpu/freebsd/gcc/macos/quantize/sanitize/ self-hosted/sycl/vad/wasm/windows, release.yml, ccache-clear action) and modified bindings-*/docker/examples + renamed examples-wasm.yml. None of these are used by the Tether fork. Restore .github to exactly match master so the PR introduces no CI-workflow changes. Co-authored-by: Cursor <cursoragent@cursor.com> --- .github/actions/ccache-clear/action.yml | 22 - .github/workflows/bindings-go.yml | 8 +- .github/workflows/bindings-ruby.yml | 17 +- .github/workflows/build-android.yml | 80 --- .github/workflows/build-clang.yml | 121 ---- .github/workflows/build-coreml.yml | 65 -- .github/workflows/build-cpu.yml | 173 ----- .github/workflows/build-freebsd.yml | 47 -- .github/workflows/build-gcc.yml | 167 ----- .github/workflows/build-macos.yml | 72 -- .github/workflows/build-quantize.yml | 48 -- .github/workflows/build-sanitize.yml | 82 --- .github/workflows/build-self-hosted.yml | 116 ---- .github/workflows/build-sycl.yml | 150 ---- .github/workflows/build-vad.yml | 50 -- .github/workflows/build-wasm.yml | 65 -- .github/workflows/build-windows.yml | 74 -- .github/workflows/docker.yml | 44 +- ...oy-examples-wasm.yml => examples-wasm.yml} | 10 +- .github/workflows/examples.yml | 12 +- .github/workflows/release.yml | 653 ------------------ 21 files changed, 40 insertions(+), 2036 deletions(-) delete mode 100644 .github/actions/ccache-clear/action.yml delete mode 100644 .github/workflows/build-android.yml delete mode 100644 .github/workflows/build-clang.yml delete mode 100644 .github/workflows/build-coreml.yml delete mode 100644 .github/workflows/build-cpu.yml delete mode 100644 .github/workflows/build-freebsd.yml delete mode 100644 .github/workflows/build-gcc.yml delete mode 100644 .github/workflows/build-macos.yml delete mode 100644 .github/workflows/build-quantize.yml delete mode 100644 .github/workflows/build-sanitize.yml delete mode 100644 .github/workflows/build-self-hosted.yml delete mode 100644 .github/workflows/build-sycl.yml delete mode 100644 .github/workflows/build-vad.yml delete mode 100644 .github/workflows/build-wasm.yml delete mode 100644 .github/workflows/build-windows.yml rename .github/workflows/{deploy-examples-wasm.yml => examples-wasm.yml} (85%) delete mode 100644 .github/workflows/release.yml diff --git a/.github/actions/ccache-clear/action.yml b/.github/actions/ccache-clear/action.yml deleted file mode 100644 index d38587efaf8..00000000000 --- a/.github/actions/ccache-clear/action.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: "ccache-clear" -description: "Delete all GitHub Actions caches matching a key prefix" -inputs: - key: - description: "Cache key prefix to match and delete" - required: true - -runs: - using: "composite" - steps: - - name: Clear caches - shell: bash - run: | - CACHES=$(gh cache list --key "ccache-${{ inputs.key }}" --json id,key --jq '.[] | "\(.id) \(.key)"' 2>/dev/null) - if [ -z "$CACHES" ]; then - echo "No caches found with key prefix: ${{ inputs.key }}" - exit 0 - fi - while read -r id key; do - echo "Deleting cache: $id ($key)" - gh cache delete "$id" - done <<< "$CACHES" diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index 91f869e99cf..83473e4636a 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -3,20 +3,20 @@ on: push: paths: - bindings/go/** - - include/whisper.h + - whisper.h pull_request: paths: - bindings/go/** - - include/whisper.h + - whisper.h jobs: ubuntu-22: runs-on: ubuntu-22.04 steps: - - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 + - uses: actions/setup-go@v6 with: go-version: '^1.23' - - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + - uses: actions/checkout@v6 - run: | cd bindings/go make test diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 8cdb7a810f7..c3f158e26e4 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -4,19 +4,8 @@ on: push: branches: - master - paths: - - bindings/ruby/** - - include/whisper.h - - examples/common-whisper.h - - ggml/include/ggml.h - pull_request: types: [opened, synchronize, reopened] - paths: - - bindings/ruby/** - - include/whisper.h - - examples/common-whisper.h - - ggml/include/ggml.h jobs: ubuntu-22: @@ -25,8 +14,8 @@ jobs: run: working-directory: bindings/ruby steps: - - uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0 + - uses: ruby/setup-ruby@v1 with: - ruby-version: '3.3' - - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + ruby-version: '3.2' + - uses: actions/checkout@v6 - run: rake test diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml deleted file mode 100644 index 571c35872c8..00000000000 --- a/.github/workflows/build-android.yml +++ /dev/null @@ -1,80 +0,0 @@ -name: CI (android) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-android.yml', - '**/CMakeLists.txt', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.java'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - android: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - with: - path: whisper - - - name: Install Java - uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 - with: - distribution: zulu - java-version: 21 - - - name: Setup Android SDK - uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 - - - name: Build - run: | - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - - name: Build with external ggml - run: | - export PATH_TO_GGML=$PWD/ggml - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - android_java: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: set up JDK 11 - uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 - with: - java-version: '11' - distribution: 'temurin' - cache: gradle - - - name: Setup Android SDK - uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 - with: - cmdline-tools-version: 9.0 - - - name: Build - run: | - cd examples/whisper.android.java - chmod +x ./gradlew - ./gradlew assembleRelease diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml deleted file mode 100644 index 20b7fec6494..00000000000 --- a/.github/workflows/build-clang.yml +++ /dev/null @@ -1,121 +0,0 @@ -name: CI (clang) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-clang.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -env: - ubuntu_image: "ubuntu:22.04" - -jobs: - ubuntu-22-clang: - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - # TODO: arm/v7 disabled due to clang bug - # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Set CCACHE_DIR - run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: clang-${{ matrix.arch }}-${{ matrix.build }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Set up QEMU - uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -v ${CCACHE_DIR}:${CCACHE_DIR} \ - -e CCACHE_DIR=${CCACHE_DIR} \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y clang build-essential cmake libsdl2-dev git ccache - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_C_COMPILER=clang \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache - make - ctest -L gh --output-on-failure' - - ubuntu-22-clang-arm64: - runs-on: ubuntu-22.04-arm - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: clang-arm64-${{ matrix.build }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y clang build-essential cmake libsdl2-dev git - - - name: Build and Test - run: | - cmake . -DWHISPER_SDL2=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_C_COMPILER=clang \ - -DGGML_NATIVE=OFF \ - -DGGML_CPU_ARM_ARCH=armv8-a - make - ctest -L gh --output-on-failure diff --git a/.github/workflows/build-coreml.yml b/.github/workflows/build-coreml.yml deleted file mode 100644 index 8dedd7819ed..00000000000 --- a/.github/workflows/build-coreml.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: CI (coreml) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - tags: - - 'v*' - paths: ['.github/workflows/build-coreml.yml', - '**/CMakeLists.txt', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.swift', - '**/*.m', - '**/*.mm', - '**/*.metal'] - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -env: - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - -jobs: - coreml-base-en: - runs-on: macos-latest - - steps: - - name: Checkout with full history - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - with: - fetch-depth: 0 - - - name: Set environment variables - id: set_vars - run: | - BUILD_NUMBER=$(git rev-list --count HEAD) - SHORT_HASH=$(git rev-parse --short=7 HEAD) - if [[ "${{ github.ref_type }}" == "tag" ]]; then - TAG_NAME="${{ github.ref_name }}" - elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - TAG_NAME="b${BUILD_NUMBER}" - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" - fi - echo "MODEL_NAME=base.en" >> $GITHUB_ENV - echo "GEN_MODEL_NAME=whisper-${TAG_NAME}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV - - - name: Download model - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - ./models/download-ggml-model.sh ${{ env.MODEL_NAME }} - - - name: Generate CoreML model - run: | - python3.11 -m venv venv - source venv/bin/activate - pip install ane_transformers openai-whisper coremltools - ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }} diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml deleted file mode 100644 index e2b74881ea5..00000000000 --- a/.github/workflows/build-cpu.yml +++ /dev/null @@ -1,173 +0,0 @@ -name: CI (cpu) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-cpu.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -# TODO: simplify the following jobs using a matrix -jobs: - ggml-ci-x64-cpu-low-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ggml-ci-x64-cpu-low-perf - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-low-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ggml-ci-arm64-cpu-low-perf - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-x64-cpu-high-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ggml-ci-x64-cpu-high-perf - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ggml-ci-arm64-cpu-high-perf - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf-sve: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ggml-ci-arm64-cpu-high-perf-sve - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt diff --git a/.github/workflows/build-freebsd.yml b/.github/workflows/build-freebsd.yml deleted file mode 100644 index 64e78ad62f8..00000000000 --- a/.github/workflows/build-freebsd.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: CI (freebsd) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-freebsd.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - freeBSD-latest: - runs-on: macos-13 - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Build - uses: cross-platform-actions/action@fe0167d8082ac584754ef3ffb567fded22642c7d # v0.27.0 - with: - operating_system: freebsd - version: '14.2' - run: | - sudo pkg update - sudo pkg install -y gmake sdl2 cmake git - cmake -B build - cmake --build build --config Release diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml deleted file mode 100644 index 53c1b2d783c..00000000000 --- a/.github/workflows/build-gcc.yml +++ /dev/null @@ -1,167 +0,0 @@ -name: CI (gcc) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-gcc.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -env: - ubuntu_image: "ubuntu:22.04" - -jobs: - ubuntu-22-gcc: - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Set CCACHE_DIR - run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: gcc-${{ matrix.arch }}-${{ matrix.build }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Set up QEMU - uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -v ${CCACHE_DIR}:${CCACHE_DIR} \ - -e CCACHE_DIR=${CCACHE_DIR} \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git ccache - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_NATIVE=OFF \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache - make - ctest -L gh --output-on-failure' - - ubuntu-22-gcc-arm64: - runs-on: ubuntu-22.04-arm - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: gcc-arm64-${{ matrix.build }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential cmake libsdl2-dev git - - - name: Configure CMake - run: | - cmake . \ - -DWHISPER_SDL2=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_NATIVE=OFF \ - -DGGML_CPU_ARM_ARCH=armv8-a - - - name: Build and Test - run: | - make - ctest -L gh --output-on-failure - - ubuntu-22-gcc-arm-v7: - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/arm/v7] - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Set CCACHE_DIR - run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: gcc-${{ matrix.arch }}-${{ matrix.build }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Set up QEMU - uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -v ${CCACHE_DIR}:${CCACHE_DIR} \ - -e CCACHE_DIR=${CCACHE_DIR} \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git ccache - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_NATIVE=OFF \ - -DGGML_CPU_ARM_ARCH=armv7-a+fp \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache - make - ctest -L gh --output-on-failure' diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml deleted file mode 100644 index 8b209e4eec8..00000000000 --- a/.github/workflows/build-macos.yml +++ /dev/null @@ -1,72 +0,0 @@ -name: CI (macOS) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-macos.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.swift', - '**/*.m', - '**/*.mm', - '**/*.metal'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - macOS-latest: - runs-on: macOS-latest - - strategy: - matrix: - destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: macos-${{ matrix.destination }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Dependencies - run: | - brew update - cmake --version - brew install sdl2 - - - name: Build - run: | - sysctl -a - cmake -B build -G Xcode \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" - cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml deleted file mode 100644 index 1c9576af7f1..00000000000 --- a/.github/workflows/build-quantize.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: CI (quantize) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-quantize.yml', - '**/CMakeLists.txt', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - quantize: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: quantize-ubuntu-22 - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Test quantize - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} - run: | - ./models/download-ggml-model.sh tiny.en - cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache - cmake --build build --config Release - ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 diff --git a/.github/workflows/build-sanitize.yml b/.github/workflows/build-sanitize.yml deleted file mode 100644 index e517f7bade4..00000000000 --- a/.github/workflows/build-sanitize.yml +++ /dev/null @@ -1,82 +0,0 @@ -name: CI (sanitize) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-sanitize.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' - - 'bindings/go/**' - - 'examples/addon.node/**' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - ubuntu-22-gcc-sanitized: - runs-on: ubuntu-22.04 - - continue-on-error: true - - strategy: - fail-fast: false - matrix: - sanitizer: [ADDRESS, THREAD, UNDEFINED] - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: sanitize-${{ matrix.sanitizer }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential cmake git - - - name: Build (undefined) - if: ${{ matrix.sanitizer == 'UNDEFINED' }} - run: | - cmake . -DCMAKE_BUILD_TYPE=Debug \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF - make - - - name: Build - if: ${{ matrix.sanitizer == 'ADDRESS' }} - run: | - cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON - make - - - name: Build (no OpenMP) - if: ${{ matrix.sanitizer == 'THREAD' }} - run: | - cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF - make - - - name: Test - if: ${{ matrix.sanitizer != 'UNDEFINED' }} - run: | - ctest -L gh --output-on-failure diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml deleted file mode 100644 index 2286b63d6e7..00000000000 --- a/.github/workflows/build-self-hosted.yml +++ /dev/null @@ -1,116 +0,0 @@ -name: CI (self-hosted) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: [ - '.github/workflows/build.yml', - '**/CMakeLists.txt', - '**/.cmake', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.swift', - '**/*.m', - '**/*.mm', - '**/*.metal', - '**/*.comp' - ] - - pull_request: - types: [opened, synchronize, reopened] - paths: [ - '.github/workflows/build-self-hosted.yml', - '**/CMakeLists.txt', - '**/.cmake', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.swift', - '**/*.m', - '**/*.mm', - '**/*.metal', - '**/*.comp' - ] - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - gpu-cuda: - runs-on: [self-hosted, Linux, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Test - id: ggml-ci - run: | - nvidia-smi - GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - gpu-vulkan-nvidia-cm: - runs-on: [self-hosted, Linux, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - gpu-vulkan-nvidia-cm2: - runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - gpu-metal: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Test - id: ggml-ci - run: | - GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - gpu-vulkan: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml deleted file mode 100644 index e5361645f1e..00000000000 --- a/.github/workflows/build-sycl.yml +++ /dev/null @@ -1,150 +0,0 @@ -name: CI (sycl) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-sycl.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - ubuntu-22-cmake-sycl: - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: sycl-${{ matrix.arch }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Build - id: cmake_build - env: - CCACHE_SLOPPINESS: time_macros - CCACHE_NODIRECT: 1 - run: | - source /opt/intel/oneapi/setvars.sh - export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" - mkdir build - cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. - cmake --build . --config Release -j $(nproc) - - ubuntu-22-cmake-sycl-fp16: - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: sycl-fp16-${{ matrix.arch }} - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Build - id: cmake_build - env: - CCACHE_SLOPPINESS: time_macros - CCACHE_NODIRECT: 1 - run: | - source /opt/intel/oneapi/setvars.sh - export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" - mkdir build - cd build - cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. - cmake --build . --config Release -j $(nproc) diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml deleted file mode 100644 index dd0efa33efe..00000000000 --- a/.github/workflows/build-vad.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: CI (vad) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-vad.yml', - '**/CMakeLists.txt', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - vad: - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: vad-ubuntu-latest - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Build - shell: bash - run: | - cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache - cmake --build build --config Release - - - name: Test - shell: bash - run: | - ctest -R ^test-vad$ --test-dir build --output-on-failure -VV diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml deleted file mode 100644 index c17a44ae455..00000000000 --- a/.github/workflows/build-wasm.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: CI (wasm) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-wasm.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - emscripten: - runs-on: ubuntu-22.04 - - strategy: - matrix: - build: [Release] - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Setup emsdk - uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 - - - name: Verify - run: emcc -v - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: wasm-ubuntu-22 - evict-old-files: 1d - save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - - - name: Build - env: - CCACHE_SLOPPINESS: time_macros,include_file_mtime,include_file_ctime - CCACHE_COMPILERCHECK: content - run: | - emcmake cmake -B build -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - "-DCMAKE_C_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" \ - "-DCMAKE_CXX_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" - cmake --build build -j $(nproc) diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml deleted file mode 100644 index 76b7a7370ce..00000000000 --- a/.github/workflows/build-windows.yml +++ /dev/null @@ -1,74 +0,0 @@ -name: CI (windows) - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - paths: ['.github/workflows/build-windows.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - windows-msys2: - runs-on: windows-latest - - strategy: - fail-fast: false - matrix: - include: - - { sys: UCRT64, env: ucrt-x86_64, compiler: gcc, build: Release } - - { sys: CLANG64, env: clang-x86_64, compiler: clang, build: Release } - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Setup ${{ matrix.sys }} - uses: msys2/setup-msys2@cafece8e6baf9247cf9b1bf95097b0b983cc558d # v2 - with: - update: true - msystem: ${{matrix.sys}} - install: >- - mingw-w64-${{matrix.env}}-${{matrix.compiler}} - mingw-w64-${{matrix.env}}-cmake - mingw-w64-${{matrix.env}}-SDL2 - mingw-w64-${{matrix.env}}-openblas - - - name: Build using CMake - shell: msys2 {0} - run: | - cmake -B build -DWHISPER_SDL2=ON - cmake --build build --config ${{ matrix.build }} -j $(nproc) - - - name: Clean after building using CMake - shell: msys2 {0} - run: | - rm -rf build - - - name: Build using CMake w/ OpenBLAS - shell: msys2 {0} - run: | - cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS - cmake --build build --config ${{ matrix.build }} -j $(nproc) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 2d95e1a697f..6c0de0ece70 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,39 +1,43 @@ name: Publish Docker image on: - workflow_dispatch: # allows manual triggering - schedule: - # Rebuild daily rather than on every push because it is expensive - - cron: '12 4 * * *' + pull_request: + push: + branches: + - master jobs: push_to_registry: name: Push Docker image to Docker Hub + if: github.event.pull_request.draft == false - runs-on: ${{ matrix.config.runs_on }} + runs-on: ubuntu-22.04 env: COMMIT_SHA: ${{ github.sha }} strategy: fail-fast: false matrix: config: - - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } - - { tag: "main-arm64", dockerfile: ".devops/main.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } - - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } - - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } - - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } - - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } - - { tag: "main-vulkan-arm64", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } + - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" } + - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" } + - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" } + - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" } + - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64" } steps: - name: Check out the repo - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + uses: actions/checkout@v6 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + with: + image: tonistiigi/binfmt:qemu-v7.0.0-28 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 + uses: docker/setup-buildx-action@v3 - name: Log in to Docker Hub - uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -58,16 +62,16 @@ jobs: id: tags run: | TAGS="ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}" - TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" + if [ "${{ github.event_name }}" == "push" ]; then + TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" + fi echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7 + uses: docker/build-push-action@v6 with: context: . - push: true + push: ${{ github.event_name == 'push' }} platforms: ${{ matrix.config.platform }} tags: ${{ steps.tags.outputs.tags }} file: ${{ matrix.config.dockerfile }} - secrets: | - HF_TOKEN=${{ secrets.HF_TOKEN }} diff --git a/.github/workflows/deploy-examples-wasm.yml b/.github/workflows/examples-wasm.yml similarity index 85% rename from .github/workflows/deploy-examples-wasm.yml rename to .github/workflows/examples-wasm.yml index 55df14720b1..927438cdad8 100644 --- a/.github/workflows/deploy-examples-wasm.yml +++ b/.github/workflows/examples-wasm.yml @@ -22,13 +22,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + uses: actions/checkout@v6 - name: Setup Pages - uses: actions/configure-pages@983d7736d9b0ae728b81ab479565c72886d7745b # v5 + uses: actions/configure-pages@v5 - name: Setup emsdk - uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 + uses: mymindstorm/setup-emsdk@v14 - name: Build WASM Examples # Enable for real build later in whisper.cpp @@ -88,10 +88,10 @@ jobs: find staging -type f | sort - name: Upload artifact - uses: actions/upload-pages-artifact@7b1f4a764d45c48632c6b24a0339c27f5614fb0b # v4 + uses: actions/upload-pages-artifact@v4 with: path: ./staging - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4 + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index ac811712e78..1c9ade5a300 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -1,15 +1,13 @@ name: Examples Tests on: push: - branches: - - master paths: - examples/addon.node/** - - include/whisper.h + - whisper.h pull_request: paths: - examples/addon.node/** - - include/whisper.h + - whisper.h jobs: addon_node-ubuntu-22: @@ -19,7 +17,7 @@ jobs: node-version: [ 16.x, 18.x ] steps: - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + uses: actions/checkout@v6 - name: Dependencies run: | @@ -29,7 +27,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6 + uses: actions/setup-node@v6 with: node-version: ${{ matrix.node-version }} cache: 'npm' @@ -42,8 +40,6 @@ jobs: run: npx cmake-js compile -T addon.node -B Release - name: Download test model - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | bash ./models/download-ggml-model.sh base.en - name: Test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml deleted file mode 100644 index 8dcfeb9827c..00000000000 --- a/.github/workflows/release.yml +++ /dev/null @@ -1,653 +0,0 @@ -name: Release - -on: - workflow_dispatch: - inputs: - create_release: - description: 'Create new release' - required: true - type: boolean - pre_release_tag: - description: 'Pre-release tag name' - required: false - type: string - - push: - tags: - - 'v*' - -env: - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -permissions: - contents: write # for creating release - -jobs: - determine-tag: - runs-on: ubuntu-latest - outputs: - tag_name: ${{ steps.tag.outputs.name }} - should_release: ${{ steps.tag.outputs.should_release }} - - steps: - - name: Checkout with full history - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - with: - fetch-depth: 0 - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER=$(git rev-list --count HEAD) - SHORT_HASH=$(git rev-parse --short=7 HEAD) - CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}" - SHOULD_RELEASE="false" - - echo "Raw values:" - echo "BUILD_NUMBER: $BUILD_NUMBER" - echo "SHORT_HASH: $SHORT_HASH" - echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}" - echo "CUSTOM_TAG: $CUSTOM_TAG" - - if [[ "${{ github.ref_type }}" == "tag" ]]; then - echo "Using pushed tag name" - TAG_NAME="${{ github.ref_name }}" - SHOULD_RELEASE="true" - elif [[ -n "$CUSTOM_TAG" ]]; then - echo "Using custom tag" - TAG_NAME="${CUSTOM_TAG}" - SHOULD_RELEASE="true" - elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then - echo "Manual release requested" - SHOULD_RELEASE="true" - TAG_NAME="b${BUILD_NUMBER}" - elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "Using master branch format" - TAG_NAME="b${BUILD_NUMBER}" - SHOULD_RELEASE="false" - else - echo "Using non-master branch format" - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" - SHOULD_RELEASE="false" - fi - - echo "Final tag name: $TAG_NAME" - echo "Should release: $SHOULD_RELEASE" - echo "name=$TAG_NAME" >> $GITHUB_OUTPUT - echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT - - ubuntu-cpu: - runs-on: ${{ matrix.os }} - needs: determine-tag - if: ${{ needs.determine-tag.outputs.should_release == 'true' }} - - strategy: - matrix: - include: - - build: x64 - os: ubuntu-22.04 - - build: arm64 - os: ubuntu-22.04-arm - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: release-${{ matrix.os }}-cpu - evict-old-files: 1d - - - name: Dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential cmake - - - name: Build - run: | - cmake -B build \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_RPATH='$ORIGIN' \ - -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ - -DGGML_BACKEND_DL=ON \ - -DGGML_NATIVE=OFF \ - ${{ matrix.build == 'x64' && '-DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_CPU_ARM_ARCH=armv8-a' }} - cmake --build build --config Release -j $(nproc) - - - name: Pack artifacts - run: | - cp LICENSE ./build/bin/ - tar -czvf whisper-bin-ubuntu-${{ matrix.build }}.tar.gz \ - --transform "s,^\.,whisper-bin-ubuntu-${{ matrix.build }}," \ - -C ./build/bin . - - - name: Upload artifacts - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - path: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz - name: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz - - windows: - runs-on: windows-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - sdl2: [ON] - include: - - arch: Win32 - s2arc: x86 - jnaPath: win32-x86 - - arch: x64 - s2arc: x64 - jnaPath: win32-x86-64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DBUILD_SHARED_LIBS=ON - -DWHISPER_SDL2=${{ matrix.sdl2 }} - -DGGML_NATIVE=OFF - ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Upload SDL2.dll - if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: ${{ matrix.s2arc }}_SDL2.dll - path: build/bin/${{ matrix.build }}/SDL2.dll - - - name: Upload whisper dll - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: whisper_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/whisper.dll - - - name: Upload ggml dll - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: ggml_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml.dll - overwrite: true - - - name: Upload ggml base dll - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: ggml_base_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-base.dll - - - name: Upload ggml cpu dll - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: ggml_cpu_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-cpu.dll - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: whisper-bin-${{ matrix.arch }}.zip - path: whisper-bin-${{ matrix.arch }}.zip - - windows-blas: - runs-on: windows-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - blas: [ON] - sdl2: [ON] - blasver: [0.3.29] - include: - - arch: Win32 - s2arc: x86 - blasfile: x86 - - arch: x64 - s2arc: x64 - blasfile: x64_64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - - - name: Install OpenBLAS and pkgconfiglite - if: matrix.blas == 'ON' - run: | - Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip" - Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}" - choco install pkgconfiglite - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DGGML_BLAS=${{ matrix.blas }} - -DGGML_BLAS_VENDOR=OpenBLAS - -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" - -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" - -DWHISPER_SDL2=${{ matrix.sdl2 }} - -DGGML_NATIVE=OFF - ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy openblas.dll - if: matrix.blas == 'ON' - run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: whisper-blas-bin-${{ matrix.arch }}.zip - path: whisper-blas-bin-${{ matrix.arch }}.zip - - windows-cublas: - runs-on: windows-2022 - needs: determine-tag - strategy: - fail-fast: false - matrix: - build: [Release] - arch: [x64] - cublas: [ON] - sdl2: [ON] - cuda-toolkit: [12.4.0, 11.8.0] - include: - - arch: x64 - sdl2: ON - sdl2_ver: 2.28.5 - steps: - - name: Clone repository - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Install Ninja - id: install_ninja - run: | - choco install ninja - - - name: Install ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} - evict-old-files: 5d - - - name: Install Cuda Toolkit 11.8.0 - if: ${{ matrix.cuda-toolkit == '11.8.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "11.8.89" - $NVCC_VER = "11.8.89" - $NVRTC_VER = "11.8.89" - $CUBLAS_VER = "11.8.1.74" - $NVTX_VER = "11.8.86" - $VS_VER = "11.8.86" - $NVPROF_VER = "11.8.87" - $CCCL_VER = "11.8.89" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Install Cuda Toolkit 12.4.0 - if: ${{ matrix.cuda-toolkit == '12.4.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "12.4.127" - $NVCC_VER = "12.4.131" - $NVRTC_VER = "12.4.127" - $CUBLAS_VER = "12.4.5.8" - $NVTX_VER = "12.4.127" - $PROFILER_VER = "12.4.127" - $VS_VER = "12.4.127" - $NVPROF_VER = "12.4.128" - $CCCL_VER = "12.4.127" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - - - name: Install 7-Zip - run: choco install 7zip -y - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip - 7z x sdl2.zip - echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append - echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt - - - name: Install cmake - run: choco install cmake - - - name: Build Project - shell: cmd - run: | - call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake --version - where cmake - if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR - ) else ( - set CUDA_FLAGS= - ) - cmake -S . -B build -G "Ninja Multi-Config" ^ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^ - -DGGML_CUDA=${{ matrix.cublas }} ^ - -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ - -DSDL2_DIR="%SDL2_DIR%" ^ - -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ - -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" ^ - -DGGML_BACKEND_DL=ON ^ - -DGGML_NATIVE=OFF ^ - -DGGML_CPU_ALL_VARIANTS=ON - set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 - cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% - - - name: Check ccache status after build - run: | - ccache --show-stats - - - name: Copy CUDA DLLs - run: | - Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | - Copy-Item -Destination "build/bin/${{ matrix.build }}" - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - - ios-xcode-build: - runs-on: macos-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - - steps: - - name: Checkout code - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - - - name: Configure - run: | - cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin - mkdir models/ggml-base.en-encoder.mlmodelc - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_SYSTEM_NAME=iOS \ - -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ - -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO - - - name: xcodebuild for swift package - id: xcodebuild - run: | - ./build-xcframework.sh - - - name: Build objc example - run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Build swiftui example - run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Pack artifacts - id: pack_artifacts - run: | - zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework - - - name: Upload artifacts - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 - with: - path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - - release: - if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }} - - runs-on: ubuntu-latest - - needs: - - determine-tag - - ubuntu-cpu - - ios-xcode-build - - windows - - windows-blas - - windows-cublas - - steps: - - name: Clone - id: checkout - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - with: - fetch-depth: 0 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.21 - with: - key: release - evict-old-files: 1d - - # Downloads all the artifacts from the previous jobs - - name: Download artifacts - id: download-artifact - uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7 - with: - path: ./artifact - - - name: Move artifacts - id: move_artifacts - run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release && mv ./artifact/*/*.tar.gz ./artifact/release 2>/dev/null || true - - - name: Create release - id: create_release - uses: ggml-org/action-create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ needs.determine-tag.outputs.tag_name }} - prerelease: ${{ github.event.inputs.pre_release_tag != '' }} - draft: true - - - name: Upload release - id: upload_release - uses: actions/github-script@ffc2c79a5b2490bd33e0a41c1de74b877714d736 # v3 - with: - github-token: ${{secrets.GITHUB_TOKEN}} - script: | - const path = require('path'); - const fs = require('fs'); - const release_id = '${{ steps.create_release.outputs.id }}'; - for (let file of await fs.readdirSync('./artifact/release')) { - if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { - console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ - owner: context.repo.owner, - repo: context.repo.repo, - release_id: release_id, - name: file, - data: await fs.readFileSync(`./artifact/release/${file}`) - }); - } - }